quandn2003 committed on
Commit 427d150 · verified · 1 parent: 5f4dac7

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .gitignore +6 -0
  2. .gradio/certificate.pem +31 -0
  3. LICENSE +674 -0
  4. README.md +157 -12
  5. README_DEMO.md +76 -0
  6. app.py +247 -0
  7. backbone.sh +179 -0
  8. config_ssl_upload.py +177 -0
  9. data/data_processing.ipynb +0 -0
  10. dataloaders/GenericSuperDatasetv2.py +445 -0
  11. dataloaders/ManualAnnoDatasetv2.py +756 -0
  12. dataloaders/PolypDataset.py +548 -0
  13. dataloaders/PolypTransforms.py +626 -0
  14. dataloaders/SimpleDataset.py +61 -0
  15. dataloaders/__init__.py +0 -0
  16. dataloaders/augutils.py +224 -0
  17. dataloaders/common.py +263 -0
  18. dataloaders/dataset_utils.py +128 -0
  19. dataloaders/dev_customized_med.py +250 -0
  20. dataloaders/image_transforms.py +362 -0
  21. dataloaders/niftiio.py +48 -0
  22. models/ProtoMedSAM.py +267 -0
  23. models/ProtoSAM.py +708 -0
  24. models/SamWrapper.py +68 -0
  25. models/__init__.py +0 -0
  26. models/__pycache__/ProtoSAM.cpython-312.pyc +0 -0
  27. models/__pycache__/SamWrapper.cpython-312.pyc +0 -0
  28. models/__pycache__/__init__.cpython-312.pyc +0 -0
  29. models/__pycache__/alpmodule.cpython-312.pyc +0 -0
  30. models/__pycache__/grid_proto_fewshot.cpython-312.pyc +0 -0
  31. models/alpmodule.py +199 -0
  32. models/backbone/__init__.py +0 -0
  33. models/backbone/__pycache__/__init__.cpython-312.pyc +0 -0
  34. models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc +0 -0
  35. models/backbone/torchvision_backbones.py +58 -0
  36. models/grid_proto_fewshot.py +427 -0
  37. models/segment_anything/__init__.py +15 -0
  38. models/segment_anything/__pycache__/__init__.cpython-312.pyc +0 -0
  39. models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc +0 -0
  40. models/segment_anything/__pycache__/build_sam.cpython-312.pyc +0 -0
  41. models/segment_anything/__pycache__/predictor.cpython-312.pyc +0 -0
  42. models/segment_anything/automatic_mask_generator.py +380 -0
  43. models/segment_anything/build_sam.py +107 -0
  44. models/segment_anything/modeling/__init__.py +11 -0
  45. models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc +0 -0
  46. models/segment_anything/modeling/__pycache__/common.cpython-312.pyc +0 -0
  47. models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc +0 -0
  48. models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc +0 -0
  49. models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc +0 -0
  50. models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ __pycache__/
2
+ exps/
3
+ .vscode/
4
+ debug/
5
+ test_*/
6
+ pretrained_model/*
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,674 @@
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md CHANGED
@@ -1,12 +1,157 @@
1
- ---
2
- title: LoGoSAM Demo
3
- emoji: 📈
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.30.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: LoGoSAM_demo
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.29.0
6
+ ---
7
+ # ProtoSAM - One-shot segmentation with foundational models
8
+
9
+ Link to our paper [here](https://arxiv.org/abs/2407.07042). \
10
+ This work is the successor to [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning) (link to the [paper](https://arxiv.org/abs/2403.03273)).
11
+
12
+ ## Demo Application
13
+
14
+ A Gradio-based demo application is now available for interactive inference with ProtoSAM. You can upload your own images and masks to test the model. See [README_DEMO.md](README_DEMO.md) for instructions on running the demo.
15
+
16
+ ## Abstract
17
+ This work introduces a new framework, ProtoSAM, for one-shot image segmentation. It combines DINOv2, a vision transformer that extracts features from images, with an Adaptive Local Prototype Pooling (ALP) layer, which generates prototypes from a support image and its mask. These prototypes are used to create an initial coarse segmentation mask by comparing the query image's features with the prototypes.
18
+ Following the extraction of an initial mask, we use numerical methods to generate prompts, such as points and bounding boxes, which are then input into the Segment Anything Model (SAM), a prompt-based segmentation model trained on natural images. This allows segmenting new classes automatically and effectively without the need for additional training.
19
+
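As a rough illustration of the prototype-matching step described above, the sketch below pools masked support features into local prototypes and scores each query location against them by cosine similarity. The function name, tensor shapes, pooling size, and threshold are assumptions chosen for the example; they do not reproduce the repository's actual ALP/ALPNet implementation.

```python
import torch
import torch.nn.functional as F

def coarse_mask_from_prototypes(supp_feat, supp_mask, query_feat, pool=4, thresh=0.5):
    """supp_feat, query_feat: (C, H, W) feature maps; supp_mask: (H, W) binary mask."""
    # Local prototypes: average masked support features over pool x pool windows,
    # normalised by the fraction of each window that is foreground.
    masked = supp_feat * supp_mask.unsqueeze(0)
    pooled_mask = F.avg_pool2d(supp_mask[None, None].float(), pool)[0, 0]        # (H/p, W/p)
    pooled_feat = F.avg_pool2d(masked.unsqueeze(0), pool)[0] / pooled_mask.clamp(min=1e-6)
    keep = pooled_mask.flatten() > 0.5                             # mostly-foreground windows
    protos = pooled_feat.flatten(1).t()[keep]                      # (N, C) local prototypes
    if protos.shape[0] == 0:                                       # fall back to one global prototype
        protos = (masked.sum(dim=(1, 2)) / supp_mask.sum().clamp(min=1)).unsqueeze(0)
    # Cosine similarity of every query location to its best-matching prototype.
    q = F.normalize(query_feat.flatten(1), dim=0)                  # (C, H*W)
    p = F.normalize(protos, dim=1)                                 # (N, C)
    sim = (p @ q).max(dim=0).values.view(query_feat.shape[1:])     # (H, W)
    return (sim > thresh).float()                                  # coarse binary mask

# e.g. coarse_mask_from_prototypes(torch.randn(64, 32, 32),
#                                  (torch.rand(32, 32) > 0.7).float(),
#                                  torch.randn(64, 32, 32))
```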
20
+ ## How To Run
21
+ ### 1. Data preprocessing
22
+ #### 1.1 CT and MRI Dataset
23
+ Please see the notebook `data/data_processing.ipynb` for instructions.
24
+ For convenience, I've compiled the data processing instructions from https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation into a single notebook. \
25
+ The CT dataset is available here: https://www.synapse.org/Synapse:syn3553734 \
26
+ The MRI dataset is available here: https://chaos.grand-challenge.org
27
+
28
+ Run `./data/CHAOST2/dcm_img_to_nii.sh` to convert the DICOM images to NIfTI files.
29
+
30
+ #### 1.2 Polyp Dataset
31
+ Data is available here: https://www.kaggle.com/datasets/hngphmv/polypdataset?select=train.csv
32
+
33
+ Put the dataset in `data/PolypDataset/`.
34
+
35
+ ### 2. Running
36
+ #### 2.1 (Optional) Training and Validation of the coarse segmentation networks
37
+ ```
38
+ ./backbone.sh [MODE] [MODALITY] [LABEL_SET]
39
+ ```
40
+ MODE - validation or training \
41
+ MODALITY - ct or mri \
42
+ LABEL_SET - 0 (kidneys), 1 (liver and spleen)
43
+
44
+ For example:
45
+ ```
46
+ ./backbone.sh training mri 1
47
+ ```
48
+ Please refer to `backbone.sh` for further configurations.
49
+
50
+ #### 2.2 Running ProtoSAM
51
+ Put all SAM checkpoints, such as sam_vit_b.pth, sam_vit_h.pth, and medsam_vit_b.pth, into the `pretrained_model` directory. \
52
+ Checkpoints are available at [SAM](https://github.com/facebookresearch/segment-anything) and [MedSAM](https://github.com/bowang-lab/MedSAM).
53
+
54
+ ```
55
+ ./run_protosam.sh [MODALITY] [LABEL_SET]
56
+ ```
57
+ MODALITY - ct, mri or polyp \
58
+ LABEL_SET (only relevant for ct or mri) - 0 (kidneys), 1 (liver and spleen) \
59
+ Please refer to the `run_protosam.sh` script for further configurations.
60
+
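As a quick sanity check that a downloaded checkpoint is usable, the snippet below loads it through the vendored `models.segment_anything` package. This assumes the vendored copy mirrors the upstream segment-anything API (`sam_model_registry`, `SamPredictor`), which the file listing suggests but which this diff view does not show in full.

```python
import torch
from models.segment_anything import sam_model_registry, SamPredictor

# Load a SAM checkpoint placed in pretrained_model/ (see above).
sam = sam_model_registry["vit_h"](checkpoint="pretrained_model/sam_vit_h.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")
predictor = SamPredictor(sam)  # ready to take box / point prompts for refinement
```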
61
+
62
+ ## Acknowledgements
63
+ This work is largely based on [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [DINOv2](https://github.com/facebookresearch/dinov2), [SAM](https://github.com/facebookresearch/segment-anything) and is a continuation of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning).
64
+
65
+ ## Cite
66
+ If you found this repo useful, please consider giving us a citation and a star!
67
+
68
+ ```bibtex
69
+ @article{ayzenberg2024protosam,
70
+ title={ProtoSAM-One Shot Medical Image Segmentation With Foundational Models},
71
+ author={Ayzenberg, Lev and Giryes, Raja and Greenspan, Hayit},
72
+ journal={arXiv preprint arXiv:2407.07042},
73
+ year={2024}
74
+ }
75
+
76
+ @misc{ayzenberg2024dinov2,
77
+ title={DINOv2 based Self Supervised Learning For Few Shot Medical Image Segmentation},
78
+ author={Lev Ayzenberg and Raja Giryes and Hayit Greenspan},
79
+ year={2024},
80
+ eprint={2403.03273},
81
+ archivePrefix={arXiv},
82
+ primaryClass={cs.CV}
83
+ }
84
+
85
+ ```
86
+
87
+ # ProtoSAM Segmentation Demo
88
+
89
+ This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.
90
+
91
+ ## Requirements
92
+
93
+ - Python 3.8 or higher
94
+ - CUDA-compatible GPU
95
+ - Required Python packages (see `requirements.txt`)
96
+
97
+ ## Setup Instructions
98
+
99
+ 1. Clone this repository:
100
+ ```bash
101
+ git clone <your-repository-url>
102
+ cd <repository-name>
103
+ ```
104
+
105
+ 2. Create and activate a virtual environment (optional but recommended):
106
+ ```bash
107
+ python -m venv venv
108
+ source venv/bin/activate # On Windows: venv\Scripts\activate
109
+ ```
110
+
111
+ 3. Install the required dependencies:
112
+ ```bash
113
+ pip install -r requirements.txt
114
+ ```
115
+
116
+ 4. Download the pretrained models:
117
+ ```bash
118
+ mkdir -p pretrained_model
119
+ # Download SAM ViT-H model
120
+ wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
121
+ mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
122
+ ```
123
+
124
+ 5. Update the model path in `app.py`:
125
+ - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.
126
+
127
+ ## Running the App
128
+
129
+ Start the Gradio app with:
130
+ ```bash
131
+ python app.py
132
+ ```
133
+
134
+ This will open a browser window with the interface for the segmentation demo.
135
+
136
+ ## Usage
137
+
138
+ 1. Upload a query image (the image you want to segment)
139
+ 2. Upload a support image (an example image with a similar object)
140
+ 3. Upload a support mask (the segmentation mask for the support image)
141
+ 4. Configure the model parameters using the checkboxes if needed
142
+ 5. Click "Run Inference" to generate the segmentation result
143
+
144
+ ## Model Configuration
145
+
146
+ The app allows you to configure several model parameters via the checkboxes:
147
+ - Use Bounding Box: Enable/disable bounding box input
148
+ - Use Points: Enable/disable point input
149
+ - Use Mask: Enable/disable mask input
150
+ - Use CCA: Enable/disable Connected Component Analysis
151
+ - Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement
152
+
153
+ ## Notes
154
+
155
+ - This demo requires a GPU with CUDA support
156
+ - Large images may require more GPU memory
157
+ - For optimal results, use high-quality support images and masks
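The "Use Bounding Box" and "Use Points" options above correspond to prompts that are derived numerically from the coarse mask before it is handed to SAM, as described in the abstract. The sketch below shows one simple way to obtain such prompts (a bounding box plus the four extreme points); the function name and the exact choice of points are assumptions for illustration, not the repository's exact procedure.

```python
import numpy as np

def mask_to_prompts(mask: np.ndarray):
    """mask: (H, W) binary array. Returns (box_xyxy, points_xy) or (None, None) if empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None, None
    box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])   # XYXY bounding box
    points = np.array([
        [xs.min(), ys[xs.argmin()]],   # leftmost foreground pixel
        [xs.max(), ys[xs.argmax()]],   # rightmost
        [xs[ys.argmin()], ys.min()],   # topmost
        [xs[ys.argmax()], ys.max()],   # bottommost
    ])
    return box, points

# box, points = mask_to_prompts(coarse_mask)   # feed as box / point prompts to a SAM predictor
```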
README_DEMO.md ADDED
@@ -0,0 +1,76 @@
1
+ # ProtoSAM Segmentation Demo
2
+
3
+ This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.
4
+
5
+ ## Requirements
6
+
7
+ - Python 3.8 or higher
8
+ - CUDA-compatible GPU
9
+ - Required Python packages (see `requirements.txt`)
10
+
11
+ ## Setup Instructions
12
+
13
+ 1. Clone this repository:
14
+ ```bash
15
+ git clone <your-repository-url>
16
+ cd <repository-name>
17
+ ```
18
+
19
+ 2. Create and activate a virtual environment (optional but recommended):
20
+ ```bash
21
+ python -m venv venv
22
+ source venv/bin/activate # On Windows: venv\Scripts\activate
23
+ ```
24
+
25
+ 3. Install the required dependencies:
26
+ ```bash
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ 4. Download the pretrained models:
31
+ ```bash
32
+ mkdir -p pretrained_model
33
+ # Download SAM ViT-H model
34
+ wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
35
+ mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
36
+ ```
37
+
38
+ 5. Update the model path in `app.py`:
39
+ - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.
40
+
41
+ ## Running the App
42
+
43
+ Start the app with:
44
+ ```bash
45
+ ./run_demo.sh
46
+ ```
47
+
48
+ Or run it directly with:
49
+ ```bash
50
+ python app.py
51
+ ```
52
+
53
+ This will start the server and provide a link to access the demo in your browser.
54
+
55
+ ## Usage
56
+
57
+ 1. Upload a query image (the image you want to segment)
58
+ 2. Upload a support image (an example image with a similar object)
59
+ 3. Upload a support mask (the segmentation mask for the support image)
60
+ 4. Configure the model parameters using the checkboxes
61
+ 5. Click "Run Inference" to generate the segmentation result
62
+
63
+ ## Model Configuration
64
+
65
+ The app allows you to configure several model parameters:
66
+ - Use Bounding Box: Enable/disable bounding box input
67
+ - Use Points: Enable/disable point input
68
+ - Use Mask: Enable/disable mask input
69
+ - Use CCA: Enable/disable Connected Component Analysis
70
+ - Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement
71
+
72
+ ## Notes
73
+
74
+ - This demo requires a GPU with CUDA support
75
+ - Large images may require more GPU memory
76
+ - For optimal results, use high-quality support images and masks
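For a feel of the overall demo flow before reading `app.py` (which follows), here is a minimal, self-contained Gradio skeleton with the same inputs and controls. `dummy_inference` is a placeholder and does not call ProtoSAM; the real application wires in the model as shown in `app.py`.

```python
import numpy as np
import gradio as gr

def dummy_inference(query_img, support_img, support_mask,
                    use_bbox, use_points, use_mask, use_cca, coarse_only):
    # Placeholder: the real app would run ProtoSAM here and return its prediction.
    if query_img is None:
        return None
    arr = np.asarray(query_img, dtype=np.float32)
    gray = arr if arr.ndim == 2 else arr.mean(axis=-1)
    return (gray > gray.mean()).astype(np.uint8) * 255   # crude stand-in "mask"

demo = gr.Interface(
    fn=dummy_inference,
    inputs=[
        gr.Image(label="Query Image"),
        gr.Image(label="Support Image"),
        gr.Image(label="Support Mask"),
        gr.Checkbox(value=True, label="Use Bounding Box"),
        gr.Checkbox(value=True, label="Use Points"),
        gr.Checkbox(value=True, label="Use Mask"),
        gr.Checkbox(value=True, label="Use CCA"),
        gr.Checkbox(value=False, label="Coarse Prediction Only"),
    ],
    outputs=gr.Image(label="Predicted Mask"),
    title="ProtoSAM Segmentation Demo (skeleton)",
)

if __name__ == "__main__":
    demo.launch()
```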
app.py ADDED
@@ -0,0 +1,247 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET
9
+ from models.grid_proto_fewshot import FewShotSeg
10
+ from models.segment_anything.utils.transforms import ResizeLongestSide
11
+
12
+ # Set environment variables for model caching
13
+ os.environ['TORCH_HOME'] = "./pretrained_model"
14
+
15
+ # Function to load the model
16
+ def load_model(config):
17
+ # Initial segmentation model
18
+ alpnet = FewShotSeg(
19
+ config["input_size"][0],
20
+ config["reload_model_path"],
21
+ config["model"]
22
+ )
23
+ alpnet.cuda()
24
+ base_model = ALPNetWrapper(alpnet)
25
+
26
+ # ProtoSAM model
27
+ sam_checkpoint = "pretrained_model/sam_vit_h.pth"
28
+ model = ProtoSAM(
29
+ image_size=(1024, 1024),
30
+ coarse_segmentation_model=base_model,
31
+ use_bbox=config["use_bbox"],
32
+ use_points=config["use_points"],
33
+ use_mask=config["use_mask"],
34
+ debug=False,
35
+ num_points_for_sam=1,
36
+ use_cca=config["do_cca"],
37
+ point_mode=config["point_mode"],
38
+ use_sam_trans=True,
39
+ coarse_pred_only=config["coarse_pred_only"],
40
+ sam_pretrained_path=sam_checkpoint,
41
+ use_neg_points=config["use_neg_points"],
42
+ )
43
+ model = model.to(torch.device("cuda"))
44
+ model.eval()
45
+ return model
46
+
47
+ # Function to preprocess images
48
+ def preprocess_image(image, transform):
49
+ if isinstance(image, np.ndarray):
50
+ image_np = image
51
+ else:
52
+ # Convert PIL Image to numpy array
53
+ image_np = np.array(image)
54
+
55
+ # Convert to RGB if grayscale
56
+ if len(image_np.shape) == 2:
57
+ image_np = np.stack([image_np] * 3, axis=2)
58
+ elif image_np.shape[2] == 1:
59
+ image_np = np.concatenate([image_np] * 3, axis=2)
60
+
61
+ # Apply transforms
62
+ image_tensor = transform(image_np).unsqueeze(0)
63
+ return image_tensor
64
+
65
+ # Function to create overlay visualization
66
+ def create_overlay(query_image, prediction, colormap='YlOrRd'):
67
+ """
68
+ Create an overlay of the prediction on the query image
69
+ """
70
+ # Convert tensors to numpy arrays for visualization
71
+ if isinstance(query_image, torch.Tensor):
72
+ query_image = query_image.cpu().squeeze().numpy()
73
+
74
+ if isinstance(prediction, torch.Tensor):
75
+ prediction = prediction.cpu().squeeze().numpy()
76
+
77
+ # Normalize image for visualization
78
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8)
79
+
80
+ # Ensure binary mask
81
+ prediction = (prediction > 0).astype(np.float32)
82
+
83
+ # Create mask overlay
84
+ mask_cmap = plt.cm.get_cmap(colormap)
85
+ pred_rgba = mask_cmap(prediction)
86
+ pred_rgba[..., 3] = prediction * 0.7 # Set alpha channel
87
+
88
+ # Create matplotlib figure for overlay
89
+ fig, ax = plt.subplots(figsize=(10, 10))
90
+
91
+ # Handle grayscale vs RGB images
92
+ if len(query_image.shape) == 2:
93
+ ax.imshow(query_image, cmap='gray')
94
+ else:
95
+ if query_image.shape[0] == 3: # Channel-first format
96
+ query_image = np.transpose(query_image, (1, 2, 0))
97
+ ax.imshow(query_image)
98
+
99
+ ax.imshow(pred_rgba)
100
+ ax.axis('off')
101
+ plt.tight_layout()
102
+
103
+ # Convert to PIL Image
104
+ fig.canvas.draw()
105
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
106
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
107
+ plt.close(fig)
108
+
109
+ return img
110
+
111
+ # Model configuration
112
+ config = {
113
+ "input_size": [224],
114
+ "reload_model_path": "path/to/your/model.pth", # Update with your model path
115
+ "model": {"encoder": "resnet50", "decoder": "pspnet"},
116
+ "use_bbox": True,
117
+ "use_points": True,
118
+ "use_mask": True,
119
+ "do_cca": True,
120
+ "point_mode": "extreme",
121
+ "coarse_pred_only": False,
122
+ "use_neg_points": False,
123
+ "base_model": TYPE_ALPNET
124
+ }
125
+
126
+ # Function to run inference
127
+ def run_inference(query_image, support_image, support_mask, use_bbox, use_points, use_mask, use_cca, coarse_pred_only):
128
+ try:
129
+ # Update config based on user selections
130
+ config["use_bbox"] = use_bbox
131
+ config["use_points"] = use_points
132
+ config["use_mask"] = use_mask
133
+ config["do_cca"] = use_cca
134
+ config["coarse_pred_only"] = coarse_pred_only
135
+
136
+ # Check if CUDA is available
137
+ if not torch.cuda.is_available():
138
+ return None, "CUDA is not available. This demo requires GPU support."
139
+
140
+ # Load the model
141
+ model = load_model(config)
142
+
143
+ # Preprocess images
144
+ sam_trans = ResizeLongestSide(1024)
145
+
146
+ # Transform for images
147
+ transform = transforms.Compose([
148
+ transforms.ToTensor(),
149
+ transforms.Resize((1024, 1024), antialias=True),
150
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
151
+ ])
152
+
153
+ # Process query image
154
+ query_img_tensor = preprocess_image(query_image, transform)
155
+
156
+ # Process support image
157
+ support_img_tensor = preprocess_image(support_image, transform)
158
+
159
+ # Process support mask (should be binary)
160
+ support_mask_np = np.array(support_mask)
161
+ support_mask_np = (support_mask_np > 127).astype(np.float32) # Binarize mask
162
+ support_mask_tensor = torch.from_numpy(support_mask_np).unsqueeze(0).unsqueeze(0)
163
+ support_mask_tensor = torch.nn.functional.interpolate(
164
+ support_mask_tensor, size=(1024, 1024), mode='nearest'
165
+ )
166
+
167
+ # Prepare model inputs
168
+ support_images = [support_img_tensor.cuda()]
169
+ support_masks = [support_mask_tensor.cuda()]
170
+
171
+ # Create model input
172
+ coarse_model_input = InputFactory.create_input(
173
+ input_type=config["base_model"],
174
+ query_image=query_img_tensor.cuda(),
175
+ support_images=support_images,
176
+ support_labels=support_masks,
177
+ isval=True,
178
+ val_wsize=3,
179
+ original_sz=query_img_tensor.shape[-2:],
180
+ img_sz=query_img_tensor.shape[-2:],
181
+ gts=None,
182
+ )
183
+ coarse_model_input.to(torch.device("cuda"))
184
+
185
+ # Run inference
186
+ with torch.no_grad():
187
+ query_pred, scores = model(
188
+ query_img_tensor.cuda(), coarse_model_input, degrees_rotate=0
189
+ )
190
+
191
+ # Create overlay visualization
192
+ result_image = create_overlay(query_img_tensor, query_pred)
193
+
194
+ confidence_score = np.mean(scores)
195
+ return result_image, f"Confidence Score: {confidence_score:.4f}"
196
+
197
+ except Exception as e:
198
+ return None, f"Error during inference: {str(e)}"
199
+
200
+ # Define the Gradio interface
201
+ def create_interface():
202
+ with gr.Blocks(title="ProtoSAM Segmentation Demo") as demo:
203
+ gr.Markdown("# ProtoSAM Segmentation Demo")
204
+ gr.Markdown("Upload a query image, support image, and support mask to generate a segmentation prediction.")
205
+
206
+ with gr.Row():
207
+ with gr.Column():
208
+ query_image = gr.Image(label="Query Image", type="pil")
209
+ support_image = gr.Image(label="Support Image", type="pil")
210
+ support_mask = gr.Image(label="Support Mask", type="pil")
211
+
212
+ with gr.Column():
213
+ result_image = gr.Image(label="Prediction Result")
214
+ result_text = gr.Textbox(label="Result Information")
215
+
216
+ with gr.Row():
217
+ with gr.Column():
218
+ use_bbox = gr.Checkbox(label="Use Bounding Box", value=True)
219
+ use_points = gr.Checkbox(label="Use Points", value=True)
220
+ use_mask = gr.Checkbox(label="Use Mask", value=True)
221
+
222
+ with gr.Column():
223
+ use_cca = gr.Checkbox(label="Use CCA", value=True)
224
+ coarse_pred_only = gr.Checkbox(label="Coarse Prediction Only", value=False)
225
+ run_button = gr.Button("Run Inference")
226
+
227
+ run_button.click(
228
+ fn=run_inference,
229
+ inputs=[
230
+ query_image,
231
+ support_image,
232
+ support_mask,
233
+ use_bbox,
234
+ use_points,
235
+ use_mask,
236
+ use_cca,
237
+ coarse_pred_only
238
+ ],
239
+ outputs=[result_image, result_text]
240
+ )
241
+
242
+ return demo
243
+
244
+ # Create and launch the interface
245
+ if __name__ == "__main__":
246
+ demo = create_interface()
247
+ demo.launch(share=True)
backbone.sh ADDED
@@ -0,0 +1,179 @@
1
+ #!/bin/bash
2
+ set -e
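+ # Assumed usage (derived from the argument checks below): ./backbone.sh <validation|training> <ct|mri> <label_set, e.g. 0 or 1>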
3
+ GPUID1=0
4
+ export CUDA_VISIBLE_DEVICES=$GPUID1
5
+
6
+ MODE=$1
7
+ if [ "$MODE" != "validation" ] && [ "$MODE" != "training" ]
8
+ then
9
+ echo "mode must be either validation or training"
10
+ exit 1
11
+ fi
12
+
13
+ # get modality as arg
14
+ MODALITY=$2
15
+ # make sure modality is either ct or mri
16
+ if [ "$MODALITY" != "ct" ] && [ "$MODALITY" != "mri" ]
17
+ then
18
+ echo "modality must be either ct or mri"
19
+ exit 1
20
+ fi
21
+
22
+ ####### Shared configs ######
23
+ PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training
24
+ INPUT_SIZE=256
25
+ ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4)
26
+ if [ $MODALITY == "ct" ]
27
+ then
28
+ DATASET='SABS_Superpix'
29
+ else
30
+ DATASET='CHAOST2_Superpix'
31
+ fi
32
+
33
+ if [ $INPUT_SIZE -gt 256 ]
34
+ then
35
+ DATASET=${DATASET}'_672'
36
+ fi
37
+
38
+ NWORKER=4
39
+ MODEL_NAME='dinov2_l14'
40
+ LORA=0
41
+ RELOAD_PATH=( "None" )
42
+ SKIP_SLICES="True"
43
+ DO_CCA="True"
44
+ TTT="False"
45
+ NSTEP=100000
46
+ RESET_AFTER_SLICE="True"
47
+ FINETUNE_ON_SUPPORT="False"
48
+ USE_SLICE_ADAPTER="False"
49
+ ADAPTER_LAYERS=1
50
+ CLAHE=False
51
+ ALL_SCALE=( "MIDDLE") # config of pseudolabels
52
+
53
+ LABEL_SETS=$3
54
+ EXCLU='[2,3]'
55
+
56
+ if [[ $MODALITY == "mri" && $LABEL_SETS -eq 1 ]]
57
+ then
58
+ echo "exluding 1, 4"
59
+ EXCLU='[1,4]' # liver(1), spleen(4)
60
+ fi
61
+
62
+ ORGANS='kidneys'
63
+ if [ $LABEL_SETS -eq 1 ]
64
+ then
65
+ ORGANS='liver_spleen'
66
+ fi
67
+
68
+
69
+ FREE_DESC=""
70
+ CPT="${MODE}_${MODEL_NAME}_${MODALITY}"
71
+ if [ -n "$FREE_DESC" ]
72
+ then
73
+ CPT="${CPT}_${FREE_DESC}"
74
+ fi
75
+
76
+ if [[ $TTT == "True" ]]
77
+ then
78
+ CPT="${CPT}_ttt_nstep_${NSTEP}"
79
+ if [ $RESET_AFTER_SLICE == "True" ]
80
+ then
81
+ CPT="${CPT}_reset_after_slice"
82
+ fi
83
+ fi
84
+
85
+ if [ $USE_SLICE_ADAPTER == "True" ]
86
+ then
87
+ CPT="${CPT}_w_adapter_${ADAPTER_LAYERS}_layers"
88
+ fi
89
+
90
+ if [ $LORA -ne 0 ]
91
+ then
92
+ CPT="${CPT}_lora_${LORA}"
93
+ fi
94
+
95
+ if [ $CLAHE == "True" ]
96
+ then
97
+ CPT="${CPT}_w_clahe"
98
+ fi
99
+
100
+ if [ $DO_CCA = "True" ]
101
+ then
102
+ CPT="${CPT}_cca"
103
+ fi
104
+
105
+ CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}"
106
+
107
+ if [ ${EXCLU} = "[]" ]
108
+ then
109
+ CPT="${CPT}_setting1"
110
+ else
111
+ CPT="${CPT}_setting2"
112
+ fi
113
+
114
+ CPT="${CPT}_${ORGANS}_fold"
115
+
116
+ ###### Training configs (irrelevant in testing) ######
117
+ DECAY=0.95
118
+
119
+ MAX_ITER=1000 # defines the size of an epoch
120
+ SNAPSHOT_INTERVAL=25000 # interval for saving snapshot
121
+ SEED='1234'
122
+
123
+ ###### Validation configs ######
124
+ SUPP_ID='[6]' # using the additionally loaded scan as support
125
+ if [ $MODALITY == "mri" ]
126
+ then
127
+ SUPP_ID='[4]'
128
+ fi
129
+
130
+ echo ===================================
131
+
132
+ for ((i=0; i<${#ALL_EV[@]}; i++))
133
+ do
134
+ EVAL_FOLD=${ALL_EV[i]}
135
+ CPT_W_FOLD="${CPT}_${EVAL_FOLD}"
136
+ echo $CPT_W_FOLD on GPU $GPUID1
137
+ for SUPERPIX_SCALE in "${ALL_SCALE[@]}"
138
+ do
139
+ PREFIX="test_vfold${EVAL_FOLD}"
140
+ echo $PREFIX
141
+ LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}"
142
+
143
+ if [ ! -d $LOGDIR ]
144
+ then
145
+ mkdir -p $LOGDIR
146
+ fi
147
+
148
+ python3 $MODE.py with \
149
+ "modelname=$MODEL_NAME" \
150
+ 'usealign=True' \
151
+ 'optim_type=sgd' \
152
+ reload_model_path=${RELOAD_PATH[i]} \
153
+ num_workers=$NWORKER \
154
+ scan_per_load=-1 \
155
+ label_sets=$LABEL_SETS \
156
+ 'use_wce=True' \
157
+ exp_prefix=$PREFIX \
158
+ 'clsname=grid_proto' \
159
+ n_steps=$NSTEP \
160
+ exclude_cls_list=$EXCLU \
161
+ eval_fold=$EVAL_FOLD \
162
+ dataset=$DATASET \
163
+ proto_grid_size=$PROTO_GRID \
164
+ max_iters_per_load=$MAX_ITER \
165
+ min_fg_data=1 seed=$SEED \
166
+ save_snapshot_every=$SNAPSHOT_INTERVAL \
167
+ superpix_scale=$SUPERPIX_SCALE \
168
+ lr_step_gamma=$DECAY \
169
+ path.log_dir=$LOGDIR \
170
+ support_idx=$SUPP_ID \
171
+ lora=$LORA \
172
+ do_cca=$DO_CCA \
173
+ ttt=$TTT \
174
+ adapter_layers=$ADAPTER_LAYERS \
175
+ use_slice_adapter=$USE_SLICE_ADAPTER \
176
+ reset_after_slice=$RESET_AFTER_SLICE \
177
+ "input_size=($INPUT_SIZE, $INPUT_SIZE)"
178
+ done
179
+ done
config_ssl_upload.py ADDED
@@ -0,0 +1,177 @@
1
+ """
2
+ Experiment configuration file
3
+ Extended from config file from original PANet Repository
4
+ """
5
+ import os
6
+ import re
7
+ import glob
8
+ import itertools
9
+
10
+ import sacred
11
+ from sacred import Experiment
12
+ from sacred.observers import FileStorageObserver
13
+ from sacred.utils import apply_backspaces_and_linefeeds
14
+
15
+ from platform import node
16
+ from datetime import datetime
17
+
18
+ from util.consts import IMG_SIZE
19
+
20
+ sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
21
+ sacred.SETTINGS.CAPTURE_MODE = 'no'
22
+
23
+ ex = Experiment('mySSL')
24
+ ex.captured_out_filter = apply_backspaces_and_linefeeds
25
+
26
+ source_folders = ['.', './dataloaders', './models', './util']
27
+ sources_to_save = list(itertools.chain.from_iterable(
28
+ [glob.glob(f'{folder}/*.py') for folder in source_folders]))
29
+ for source_file in sources_to_save:
30
+ ex.add_source_file(source_file)
31
+
32
+ @ex.config
33
+ def cfg():
34
+ """Default configurations"""
35
+ seed = 1234
36
+ gpu_id = 0
37
+ mode = 'train' # for now only allows 'train'
38
+ do_validation=False
39
+ num_workers = 4 # 0 for debugging.
40
+
41
+ dataset = 'CHAOST2' # i.e. abdominal MRI
42
+ use_coco_init = True # initialize the backbone with MS-COCO pretrained weights (note that COCO contains no medical images)
43
+
44
+ ### Training
45
+ n_steps = 100100
46
+ batch_size = 1
47
+ lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
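+ # e.g. with n_steps = 100100 this yields milestones at 1000, 2000, ..., 99000 iterations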
48
+ lr_step_gamma = 0.95
49
+ ignore_label = 255
50
+ print_interval = 100
51
+ save_snapshot_every = 25000
52
+ max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
53
+ epochs=1
54
+ scan_per_load = -1 # numbers of 3d scans per load for saving memory. If -1, load the entire dataset to the memory
55
+ which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
56
+ input_size = (IMG_SIZE, IMG_SIZE)
57
+ min_fg_data='100' # when training with manual annotations, the minimum number of foreground pixels required for a single class in a single slice. This empirically stabilizes the training process
58
+ label_sets = 0 # which group of labels taking as training (the rest are for testing)
59
+ curr_cls = "" # choose between rk, lk, spleen and liver
60
+ exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
61
+ usealign = True # see vanilla PANet
62
+ use_wce = True
63
+ use_dinov2_loss = False
64
+ dice_loss = False
65
+ ### Validation
66
+ z_margin = 0
67
+ eval_fold = 0 # which fold for 5 fold cross validation
68
+ support_idx=[-1] # indicating which scan is used as support in testing.
69
+ val_wsize=2 # L_H, L_W in testing
70
+ n_sup_part = 3 # number of chunks in testing
71
+ use_clahe = False
72
+ use_slice_adapter = False
73
+ adapter_layers=3
74
+ debug=True
75
+ skip_no_organ_slices=True
76
+ # Network
77
+ modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
78
+ clsname = None #
79
+ reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
80
+ proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
81
+ feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size, should couple this with backbone in future
82
+ lora = 0
83
+ use_3_slices=False
84
+ do_cca=False
85
+ use_edge_detector=False
86
+ finetune_on_support=False
87
+ sliding_window_confidence_segmentation=False
88
+ finetune_model_on_single_slice=False
89
+ online_finetuning=True
90
+
91
+ use_bbox=True # for SAM
92
+ use_points=True # for SAM
93
+ use_mask=False # for SAM
94
+ base_model="alpnet" # or "SAM"
95
+ # SSL
96
+ superpix_scale = 'MIDDLE' #MIDDLE/ LARGE
97
+ use_pos_enc=False
98
+ support_txt_file = None # path to a txt file containing support slices
99
+ augment_support_set=False
100
+ coarse_pred_only=False # for ProtoSAM
101
+ point_mode="both" # for ProtoSAM, choose: both, conf, centroid
102
+ use_neg_points=False
103
+ n_support=1 # num support images
104
+ protosam_sam_ver="sam_h" # or medsam
105
+ grad_accumulation_steps=1
106
+ ttt=False
107
+ reset_after_slice=True # for TTT, if to reset the model after finetuning on each slice
108
+ model = {
109
+ 'align': usealign,
110
+ 'dinov2_loss': use_dinov2_loss,
111
+ 'use_coco_init': use_coco_init,
112
+ 'which_model': modelname,
113
+ 'cls_name': clsname,
114
+ 'proto_grid_size' : proto_grid_size,
115
+ 'feature_hw': feature_hw,
116
+ 'reload_model_path': reload_model_path,
117
+ 'lora': lora,
118
+ 'use_slice_adapter': use_slice_adapter,
119
+ 'adapter_layers': adapter_layers,
120
+ 'debug': debug,
121
+ 'use_pos_enc': use_pos_enc
122
+ }
123
+
124
+ task = {
125
+ 'n_ways': 1,
126
+ 'n_shots': 1,
127
+ 'n_queries': 1,
128
+ 'npart': n_sup_part
129
+ }
130
+
131
+ optim_type = 'sgd'
132
+ lr=1e-3
133
+ momentum=0.9
134
+ weight_decay=0.0005
135
+ optim = {
136
+ 'lr': lr,
137
+ 'momentum': momentum,
138
+ 'weight_decay': weight_decay
139
+ }
140
+
141
+ exp_prefix = ''
142
+
143
+ exp_str = '_'.join(
144
+ [exp_prefix]
145
+ + [dataset,]
146
+ + [f'sets_{label_sets}_{task["n_shots"]}shot'])
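+ # e.g. with the defaults above, exp_str == '_CHAOST2_sets_0_1shot'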
147
+
148
+ path = {
149
+ 'log_dir': './runs',
150
+ 'SABS':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"
151
+ },
152
+ 'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"
153
+ },
154
+ 'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"
155
+ },
156
+ 'C0':{'data_dir': "feed your dataset path here"
157
+ },
158
+ 'CHAOST2':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"
159
+ },
160
+ 'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"
161
+ },
162
+ 'SABS_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"},
163
+ 'C0_Superpix':{'data_dir': "feed your dataset path here"},
164
+ 'CHAOST2_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"},
165
+ 'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"},
166
+ 'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"},
167
+ 'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"},
168
+ }
169
+
170
+
171
+ @ex.config_hook
172
+ def add_observer(config, command_name, logger):
173
+ """A hook fucntion to add observer"""
174
+ exp_name = f'{ex.path}_{config["exp_str"]}'
175
+ observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
176
+ ex.observers.append(observer)
177
+ return config
data/data_processing.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dataloaders/GenericSuperDatasetv2.py ADDED
@@ -0,0 +1,445 @@
1
+ """
2
+ Dataset for training with pseudolabels
3
+ TODO:
4
+ 1. Merge with the manually annotated dataset
5
+ 2. superpixel_scale -> superpix_config, feed like a dict
6
+ """
7
+ import glob
8
+ import numpy as np
9
+ import dataloaders.augutils as myaug
10
+ import torch
11
+ import random
12
+ import os
13
+ import copy
14
+ import platform
15
+ import json
16
+ import re
17
+ import cv2
18
+ from dataloaders.common import BaseDataset, Subset
19
+ from dataloaders.dataset_utils import *
20
+ from pdb import set_trace
21
+ from util.utils import CircularList
22
+ from util.consts import IMG_SIZE
23
+
24
+ class SuperpixelDataset(BaseDataset):
25
+ def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], train_list = [], superpix_scale = 'SMALL', norm_mean=None, norm_std=None, supervised_train=False, use_3_slices=False, **kwargs):
26
+ """
27
+ Pseudolabel dataset
28
+ Args:
29
+ which_dataset: name of the dataset to use
30
+ base_dir: directory of dataset
31
+ idx_split: index of data split as we will do cross validation
32
+ mode: 'train', 'val'.
33
+ nsup: number of scans used as support. currently idle for superpixel dataset
34
+ transforms: data transform (augmentation) function
35
+ scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time
36
+ num_rep: Number of augmentation applied for a same pseudolabel
37
+ tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
38
+ fix_length: fix the length of dataset
39
+ exclude_list: Labels to be excluded
40
+ superpix_scale: config of superpixels
41
+ """
42
+ super(SuperpixelDataset, self).__init__(base_dir)
43
+
44
+ self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
45
+ self.sep = DATASET_INFO[which_dataset]['_SEP']
46
+ self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME']
47
+ self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
48
+
49
+ self.image_size = image_size
50
+ self.transforms = transforms
51
+ self.is_train = True if mode == 'train' else False
52
+ self.supervised_train = supervised_train
53
+ if self.supervised_train and len(train_list) == 0:
54
+ raise Exception('Please provide training labels')
55
+ # assert mode == 'train'
56
+ self.fix_length = fix_length
57
+ if self.supervised_train:
58
+ # self.nclass = len(self.real_label_name)
59
+ self.nclass = len(self.pseu_label_name)
60
+ else:
61
+ self.nclass = len(self.pseu_label_name)
62
+ self.num_rep = num_rep
63
+ self.tile_z_dim = tile_z_dim
64
+ self.use_3_slices = use_3_slices
65
+ if tile_z_dim > 1 and self.use_3_slices:
66
+ raise Exception("tile_z_dim and use_3_slices shouldn't be used together")
67
+
68
+ # find scans in the data folder
69
+ self.nsup = nsup
70
+ self.base_dir = base_dir
71
+ self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ]
72
+ self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x)))
73
+
74
+ # experiment configs
75
+ self.exclude_lbs = exclude_list
76
+ self.train_list = train_list
77
+ self.superpix_scale = superpix_scale
78
+ if len(exclude_list) > 0:
79
+ print(f'###### Dataset: the following classes have been excluded: {exclude_list} ######')
80
+ self.idx_split = idx_split
81
+ self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
82
+ self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
83
+ self.scan_per_load = scan_per_load
84
+
85
+ self.info_by_scan = None
86
+ self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
87
+ self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()], ct_mean=norm_mean, ct_std=norm_std)
88
+
89
+ if self.is_train:
90
+ if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch
91
+ self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
92
+ else: # load the entire set without a buffer
93
+ self.pid_curr_load = self.scan_ids
94
+ elif mode == 'val':
95
+ self.pid_curr_load = self.scan_ids
96
+ else:
97
+ raise Exception
98
+
99
+ self.use_clahe = False
100
+ if kwargs['use_clahe']:
101
+ self.use_clahe = True
102
+ clip_limit = 4.0 if self.img_modality == 'MR' else 2.0
103
+ self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7,7))
104
+
105
+ self.actual_dataset = self.read_dataset()
106
+ self.size = len(self.actual_dataset)
107
+ self.overall_slice_by_cls = self.read_classfiles()
108
+
109
+ print("###### Initial scans loaded: ######")
110
+ print(self.pid_curr_load)
111
+
112
+ def get_scanids(self, mode, idx_split):
113
+ """
114
+ Load scans by train-test split
115
+ leaving one additional scan as the support scan. if the last fold, taking scan 0 as the additional one
116
+ Args:
117
+ idx_split: index for spliting cross-validation folds
118
+ """
119
+ val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
120
+ if mode == 'train':
121
+ return [ ii for ii in self.img_pids if ii not in val_ids ]
122
+ elif mode == 'val':
123
+ return val_ids
124
+
125
+ def reload_buffer(self):
126
+ """
127
+ Reload a only portion of the entire dataset, if the dataset is too large
128
+ 1. delete original buffer
129
+ 2. update self.ids_this_batch
130
+ 3. update other internel variables like __len__
131
+ """
132
+ if self.scan_per_load <= 0:
133
+ print("We are not using the reload buffer, doing notiong")
134
+ return -1
135
+
136
+ del self.actual_dataset
137
+ del self.info_by_scan
138
+
139
+ self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
140
+ self.actual_dataset = self.read_dataset()
141
+ self.size = len(self.actual_dataset)
142
+ self.update_subclass_lookup()
143
+ print(f'Loader buffer reloaded with a new size of {self.size} slices')
144
+
145
+ def organize_sample_fids(self):
146
+ out_list = {}
147
+ for curr_id in self.scan_ids:
148
+ curr_dict = {}
149
+
150
+ _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
151
+ _lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz')
152
+ _gt_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
153
+
154
+ curr_dict["img_fid"] = _img_fid
155
+ curr_dict["lbs_fid"] = _lb_fid
156
+ curr_dict["gt_lbs_fid"] = _gt_lb_fid
157
+ out_list[str(curr_id)] = curr_dict
158
+ return out_list
159
+
160
+ def read_dataset(self):
161
+ """
162
+ Read images into memory and store them in 2D
163
+ Build tables for the position of an individual 2D slice in the entire dataset
164
+ """
165
+ out_list = []
166
+ self.scan_z_idx = {}
167
+ self.info_by_scan = {} # meta data of each scan
168
+ glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
169
+
170
+ for scan_id, itm in self.img_lb_fids.items():
171
+ if scan_id not in self.pid_curr_load:
172
+ continue
173
+
174
+ img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
175
+ # read connected graph of labels
176
+ if self.use_clahe:
177
+ # img = nself.clahe.apply(img.astype(np.uint8))
178
+ if self.img_modality == 'MR':
179
+ img = np.stack([((slice - slice.min()) / (slice.max() - slice.min())) * 255 for slice in img], axis=0)
180
+ img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0)
181
+
182
+ img = img.transpose(1,2,0)
183
+ self.info_by_scan[scan_id] = _info
184
+
185
+ img = np.float32(img)
186
+ img = self.norm_func(img)
187
+ self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
188
+
189
+ if self.supervised_train:
190
+ lb = read_nii_bysitk(itm["gt_lbs_fid"])
191
+ else:
192
+ lb = read_nii_bysitk(itm["lbs_fid"])
193
+ lb = lb.transpose(1,2,0)
194
+ lb = np.int32(lb)
195
+
196
+ # resize img and lb to self.image_size
197
+ img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
198
+ lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
199
+
200
+ # format of slices: [axial_H x axial_W x Z]
201
+ if self.supervised_train:
202
+ # remove all images that dont have the training labels
203
+ del_indices = [i for i in range(img.shape[-1]) if not np.any(np.isin(lb[..., i], self.train_list))]
204
+ # create an new img and lb without indices in del_indices
205
+ new_img = np.zeros((img.shape[0], img.shape[1], img.shape[2] - len(del_indices)))
206
+ new_lb = np.zeros((lb.shape[0], lb.shape[1], lb.shape[2] - len(del_indices)))
207
+ new_img = img[..., ~np.isin(np.arange(img.shape[-1]), del_indices)]
208
+ new_lb = lb[..., ~np.isin(np.arange(lb.shape[-1]), del_indices)]
209
+
210
+ img = new_img
211
+ lb = new_lb
212
+ a = [i for i in range(img.shape[-1]) if lb[...,i].max() == 0]
213
+
214
+ nframes = img.shape[-1]
215
+ assert img.shape[-1] == lb.shape[-1]
216
+ base_idx = img.shape[-1] // 2 # index of the middle slice
217
+
218
+ # re-organize 3D images into 2D slices and record essential information for each slice
219
+ out_list.append( {"img": img[..., 0: 1],
220
+ "lb":lb[..., 0: 0 + 1],
221
+ "sup_max_cls": lb[..., 0: 0 + 1].max(),
222
+ "is_start": True,
223
+ "is_end": False,
224
+ "nframe": nframes,
225
+ "scan_id": scan_id,
226
+ "z_id":0,
227
+ })
228
+
229
+ self.scan_z_idx[scan_id][0] = glb_idx
230
+ glb_idx += 1
231
+
232
+ for ii in range(1, img.shape[-1] - 1):
233
+ out_list.append( {"img": img[..., ii: ii + 1],
234
+ "lb":lb[..., ii: ii + 1],
235
+ "is_start": False,
236
+ "is_end": False,
237
+ "sup_max_cls": lb[..., ii: ii + 1].max(),
238
+ "nframe": nframes,
239
+ "scan_id": scan_id,
240
+ "z_id": ii,
241
+ })
242
+ self.scan_z_idx[scan_id][ii] = glb_idx
243
+ glb_idx += 1
244
+
245
+ ii += 1 # last slice of a 3D volume
246
+ out_list.append( {"img": img[..., ii: ii + 1],
247
+ "lb":lb[..., ii: ii+ 1],
248
+ "is_start": False,
249
+ "is_end": True,
250
+ "sup_max_cls": lb[..., ii: ii + 1].max(),
251
+ "nframe": nframes,
252
+ "scan_id": scan_id,
253
+ "z_id": ii,
254
+ })
255
+
256
+ self.scan_z_idx[scan_id][ii] = glb_idx
257
+ glb_idx += 1
258
+
259
+ return out_list
260
+
261
+ def read_classfiles(self):
262
+ """
263
+ Load the scan-slice-class indexing file
264
+ """
265
+ with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
266
+ cls_map = json.load( fopen)
267
+ fopen.close()
268
+
269
+ with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
270
+ self.tp1_cls_map = json.load( fopen)
271
+ fopen.close()
272
+
273
+ return cls_map
274
+
275
+ def get_superpixels_similarity(self, sp1, sp2):
276
+ pass
277
+
278
+ def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
279
+ if bi_val is None:
280
+ # bi_val = np.random.randint(1, sup_max_cls)
281
+ bi_val = random.choice(list(np.unique(super_map)))
282
+ if conn_graph is not None and img is not None:
283
+ # get number of neighbors of bi_val
284
+ neighbors = conn_graph[bi_val]
285
+ # pick a random number of neighbors and merge them
286
+ n_neighbors = np.random.randint(0, len(neighbors))
287
+ try:
288
+ neighbors = random.sample(neighbors, n_neighbors)
289
+ except TypeError:
290
+ neighbors = []
291
+ # merge neighbors
292
+ super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
293
+ return np.float32(super_map == bi_val)
294
+
295
+ def supcls_pick(self, super_map):
296
+ return random.choice(list(np.unique(super_map)))
297
+
298
+ def get_3_slice_adjacent_image(self, image_t, index):
299
+ curr_dict = self.actual_dataset[index]
300
+ prev_image = np.zeros_like(image_t)
301
+
302
+ if index > 1 and not curr_dict["is_start"]:
303
+ prev_dict = self.actual_dataset[index - 1]
304
+ prev_image = prev_dict["img"]
305
+
306
+ next_image = np.zeros_like(image_t)
307
+ if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
308
+ next_dict = self.actual_dataset[index + 1]
309
+ next_image = next_dict["img"]
310
+
311
+ image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
312
+
313
+ return image_t
314
+
315
+ def __getitem__(self, index):
316
+ index = index % len(self.actual_dataset)
317
+ curr_dict = self.actual_dataset[index]
318
+ sup_max_cls = curr_dict['sup_max_cls']
319
+ if sup_max_cls < 1:
320
+ return self.__getitem__(index + 1)
321
+
322
+ image_t = curr_dict["img"]
323
+ label_raw = curr_dict["lb"]
324
+
325
+ if self.use_3_slices:
326
+ image_t = self.get_3_slice_adjacent_image(image_t, index)
327
+
328
+ for _ex_cls in self.exclude_lbs:
329
+ if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen
330
+ return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
331
+
332
+ if self.supervised_train:
333
+ superpix_label = -1
334
+ label_t = np.float32(label_raw)
335
+
336
+ lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list)))
337
+ label_t[label_t != lb_id] = 0
338
+ label_t[label_t == lb_id] = 1
339
+
340
+ else:
341
+ superpix_label = self.supcls_pick(label_raw)
342
+ label_t = np.float32(label_raw == superpix_label)
343
+
344
+ pair_buffer = []
345
+
346
+ comp = np.concatenate( [image_t, label_t], axis = -1 )
347
+
348
+ for ii in range(self.num_rep):
349
+ if self.transforms is not None:
350
+ img, lb = self.transforms(comp, c_img = image_t.shape[-1], c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False)
351
+ else:
352
+ img, lb = comp[:, :, 0:1], comp[:, :, 1:2]
353
+ # if ii % 2 == 0:
354
+ # label_raw = lb
355
+ # lb = lb == superpix_label
356
+
357
+ img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ).float()
358
+ lb = torch.from_numpy( lb.squeeze(-1)).float()
359
+
360
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
361
+
362
+ is_start = curr_dict["is_start"]
363
+ is_end = curr_dict["is_end"]
364
+ nframe = np.int32(curr_dict["nframe"])
365
+ scan_id = curr_dict["scan_id"]
366
+ z_id = curr_dict["z_id"]
367
+
368
+ sample = {"image": img,
369
+ "label":lb,
370
+ "is_start": is_start,
371
+ "is_end": is_end,
372
+ "nframe": nframe,
373
+ "scan_id": scan_id,
374
+ "z_id": z_id
375
+ }
376
+
377
+ # Add auxiliary attributes
378
+ if self.aux_attrib is not None:
379
+ for key_prefix in self.aux_attrib:
380
+ # Process the data sample, create new attributes and save them in a dictionary
381
+ aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
382
+ for key_suffix in aux_attrib_val:
383
+ # one function may create multiple attributes, so we need suffix to distinguish them
384
+ sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
385
+ pair_buffer.append(sample)
386
+
387
+ support_images = []
388
+ support_mask = []
389
+ support_class = []
390
+
391
+ query_images = []
392
+ query_labels = []
393
+ query_class = []
394
+
395
+ for idx, itm in enumerate(pair_buffer):
396
+ if idx % 2 == 0:
397
+ support_images.append(itm["image"])
398
+ support_class.append(1) # pseudolabel class
399
+ support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] ))
400
+ else:
401
+ query_images.append(itm["image"])
402
+ query_class.append(1)
403
+ query_labels.append( itm["label"])
404
+
405
+ return {'class_ids': [support_class],
406
+ 'support_images': [support_images], #
407
+ 'superpix_label': superpix_label,
408
+ 'superpix_label_raw': label_raw[:,:,0],
409
+ 'support_mask': [support_mask],
410
+ 'query_images': query_images, #
411
+ 'query_labels': query_labels,
412
+ 'scan_id': scan_id,
413
+ 'z_id': z_id,
414
+ 'nframe': nframe,
415
+ }
416
+
417
+
418
+ def __len__(self):
419
+ """
420
+ copy-paste from basic naive dataset configuration
421
+ """
422
+ if self.fix_length != None:
423
+ assert self.fix_length >= len(self.actual_dataset)
424
+ return self.fix_length
425
+ else:
426
+ return len(self.actual_dataset)
427
+
428
+ def getMaskMedImg(self, label, class_id, class_ids):
429
+ """
430
+ Generate FG/BG mask from the segmentation mask
431
+
432
+ Args:
433
+ label: semantic mask
434
+ class_id: semantic class of interest
435
+ class_ids: all class id in this episode
436
+ """
437
+ fg_mask = torch.where(label == class_id,
438
+ torch.ones_like(label), torch.zeros_like(label))
439
+ bg_mask = torch.where(label != class_id,
440
+ torch.ones_like(label), torch.zeros_like(label))
441
+ for class_id in class_ids:
442
+ bg_mask[label == class_id] = 0
443
+
444
+ return {'fg_mask': fg_mask,
445
+ 'bg_mask': bg_mask}
dataloaders/ManualAnnoDatasetv2.py ADDED
@@ -0,0 +1,756 @@
1
+ """
2
+ Manually labeled dataset
3
+ TODO:
4
+ 1. Merge with superpixel dataset
5
+ """
6
+ import glob
7
+ import numpy as np
8
+ import dataloaders.augutils as myaug
9
+ import torch
10
+ import random
11
+ import os
12
+ import copy
13
+ import platform
14
+ import json
15
+ import re
16
+ import cv2
17
+ from dataloaders.common import BaseDataset, Subset, ValidationDataset
18
+ # from common import BaseDataset, Subset
19
+ from dataloaders.dataset_utils import *
20
+ from pdb import set_trace
21
+ from util.utils import CircularList
22
+ from util.consts import IMG_SIZE
23
+
24
+ MODE_DEFAULT = "default"
25
+ MODE_FULL_SCAN = "full_scan"
26
+
27
+ class ManualAnnoDataset(BaseDataset):
28
+ def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None, **kwargs):
29
+ """
30
+ Manually labeled dataset
31
+ Args:
32
+ which_dataset: name of the dataset to use
33
+ base_dir: directory of dataset
34
+ idx_split: index of data split as we will do cross validation
35
+ mode: 'train', 'val'.
36
+ transforms: data transform (augmentation) function
37
+ min_fg: minimum number of positive pixels in a 2D slice, mainly to stabilize training when trained on a manually labeled dataset
38
+ scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time
39
+ tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
40
+ nsup: number of support scans
41
+ fix_length: fix the length of dataset
42
+ exclude_list: Labels to be excluded
43
+ extern_normalize_func: normalization function used for data pre-processing
44
+ """
45
+ super(ManualAnnoDataset, self).__init__(base_dir)
46
+ self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
47
+ self.sep = DATASET_INFO[which_dataset]['_SEP']
48
+ self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
49
+ self.image_size = image_size
50
+ self.transforms = transforms
51
+ self.is_train = True if mode == 'train' else False
52
+ self.phase = mode
53
+ self.fix_length = fix_length
54
+ self.all_label_names = self.label_name
55
+ self.nclass = len(self.label_name)
56
+ self.tile_z_dim = tile_z_dim
57
+ self.base_dir = base_dir
58
+ self.nsup = nsup
59
+ self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ]
60
+ self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of splitting folds
61
+ if 'use_clahe' not in kwargs:
62
+ self.use_clahe = False
63
+ else:
64
+ self.use_clahe = kwargs['use_clahe']
65
+ if self.use_clahe:
66
+ self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7))
67
+
68
+ self.use_3_slices = kwargs["use_3_slices"] if 'use_3_slices' in kwargs else False
69
+ if self.use_3_slices:
70
+ self.tile_z_dim=1
71
+
72
+ self.get_item_mode = MODE_DEFAULT
73
+ if 'get_item_mode' in kwargs:
74
+ self.get_item_mode = kwargs['get_item_mode']
75
+
76
+ self.exclude_lbs = exclude_list
77
+ if len(exclude_list) > 0:
78
+ print(f'###### Dataset: the following classes have been excluded: {exclude_list} ######')
79
+
80
+ self.idx_split = idx_split
81
+ self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
82
+ self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
83
+
84
+ self.scan_per_load = scan_per_load
85
+
86
+ self.info_by_scan = None
87
+ self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
88
+
89
+ if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset.
90
+ self.norm_func = extern_normalize_func
91
+ print(f'###### Dataset: using external normalization statistics ######')
92
+ else:
93
+ self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()])
94
+ print(f'###### Dataset: using normalization statistics calculated from loaded data ######')
95
+
96
+ if self.is_train:
97
+ if scan_per_load > 0: # buffer needed
98
+ self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
99
+ else: # load the entire set without a buffer
100
+ self.pid_curr_load = self.scan_ids
101
+ elif mode == 'val':
102
+ self.pid_curr_load = self.scan_ids
103
+ self.potential_support_sid = []
104
+ else:
105
+ raise Exception
106
+ self.actual_dataset = self.read_dataset()
107
+ self.size = len(self.actual_dataset)
108
+ self.overall_slice_by_cls = self.read_classfiles()
109
+ self.update_subclass_lookup()
110
+
111
+ def get_scanids(self, mode, idx_split):
112
+ val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
113
+ self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index
114
+ if mode == 'train':
115
+ return [ ii for ii in self.img_pids if ii not in val_ids ]
116
+ elif mode == 'val':
117
+ return val_ids
118
+
119
+ def reload_buffer(self):
120
+ """
121
+ Reload a portion of the entire dataset, if the dataset is too large
122
+ 1. delete original buffer
123
+ 2. update self.ids_this_batch
124
+ 3. update other internel variables like __len__
125
+ """
126
+ if self.scan_per_load <= 0:
127
+ print("We are not using the reload buffer, doing notiong")
128
+ return -1
129
+
130
+ del self.actual_dataset
131
+ del self.info_by_scan
132
+ self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
133
+ self.actual_dataset = self.read_dataset()
134
+ self.size = len(self.actual_dataset)
135
+ self.update_subclass_lookup()
136
+ print(f'Loader buffer reloaded with a new size of {self.size} slices')
137
+
138
+ def organize_sample_fids(self):
139
+ out_list = {}
140
+ for curr_id in self.scan_ids:
141
+ curr_dict = {}
142
+
143
+ _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
144
+ _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
145
+
146
+ curr_dict["img_fid"] = _img_fid
147
+ curr_dict["lbs_fid"] = _lb_fid
148
+ out_list[str(curr_id)] = curr_dict
149
+ return out_list
150
+
151
+ def read_dataset(self):
152
+ """
153
+ Build index pointers to individual slices
154
+ Also keep a look-up table from scan_id, slice to index
155
+ """
156
+ out_list = []
157
+ self.scan_z_idx = {}
158
+ self.info_by_scan = {} # meta data of each scan
159
+ glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
160
+
161
+ for scan_id, itm in self.img_lb_fids.items():
162
+ if scan_id not in self.pid_curr_load:
163
+ continue
164
+
165
+ img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
166
+
167
+ img = img.transpose(1,2,0)
168
+
169
+ self.info_by_scan[scan_id] = _info
170
+
171
+ if self.use_clahe:
172
+ img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0)
173
+
174
+ img = np.float32(img)
175
+ img = self.norm_func(img)
176
+
177
+ self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
178
+
179
+ lb = read_nii_bysitk(itm["lbs_fid"])
180
+ lb = lb.transpose(1,2,0)
181
+
182
+ lb = np.float32(lb)
183
+
184
+ img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
185
+ lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
186
+
187
+ assert img.shape[-1] == lb.shape[-1]
188
+ base_idx = img.shape[-1] // 2 # index of the middle slice
189
+
190
+ # write the beginning frame
191
+ out_list.append( {"img": img[..., 0: 1],
192
+ "lb":lb[..., 0: 0 + 1],
193
+ "is_start": True,
194
+ "is_end": False,
195
+ "nframe": img.shape[-1],
196
+ "scan_id": scan_id,
197
+ "z_id":0})
198
+
199
+ self.scan_z_idx[scan_id][0] = glb_idx
200
+ glb_idx += 1
201
+
202
+ for ii in range(1, img.shape[-1] - 1):
203
+ out_list.append( {"img": img[..., ii: ii + 1],
204
+ "lb":lb[..., ii: ii + 1],
205
+ "is_start": False,
206
+ "is_end": False,
207
+ "nframe": -1,
208
+ "scan_id": scan_id,
209
+ "z_id": ii
210
+ })
211
+ self.scan_z_idx[scan_id][ii] = glb_idx
212
+ glb_idx += 1
213
+
214
+ ii += 1 # last frame, note the is_end flag
215
+ out_list.append( {"img": img[..., ii: ii + 1],
216
+ "lb":lb[..., ii: ii+ 1],
217
+ "is_start": False,
218
+ "is_end": True,
219
+ "nframe": -1,
220
+ "scan_id": scan_id,
221
+ "z_id": ii
222
+ })
223
+
224
+ self.scan_z_idx[scan_id][ii] = glb_idx
225
+ glb_idx += 1
226
+
227
+ return out_list
228
+
229
+ def read_classfiles(self):
230
+ with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
231
+ cls_map = json.load( fopen)
232
+ fopen.close()
233
+
234
+ with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
235
+ self.tp1_cls_map = json.load( fopen)
236
+ fopen.close()
237
+
238
+ return cls_map
239
+
240
+ def __getitem__(self, index):
241
+ if self.get_item_mode == MODE_DEFAULT:
242
+ return self.__getitem_default__(index)
243
+ elif self.get_item_mode == MODE_FULL_SCAN:
244
+ return self.__get_ct_scan___(index)
245
+ else:
246
+ raise NotImplementedError("Unknown mode")
247
+
248
+
249
+ def __get_ct_scan___(self, index):
250
+ scan_n = index % len(self.scan_z_idx)
251
+ scan_id = list(self.scan_z_idx.keys())[scan_n]
252
+ scan_slices = self.scan_z_idx[scan_id]
253
+
254
+ scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
255
+
256
+ scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
257
+
258
+ scan_imgs = np.float32(scan_imgs)
259
+ scan_lbs = np.float32(scan_lbs)
260
+
261
+ scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0)
262
+ scan_lbs = torch.from_numpy(scan_lbs)
263
+
264
+ if self.tile_z_dim:
265
+ scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
266
+ assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
267
+
268
+ # # reshape to C, D, H, W
269
+ # scan_imgs = scan_imgs.permute(1, 0, 2, 3)
270
+ # scan_lbs = scan_lbs.permute(1, 0, 2, 3)
271
+
272
+ sample = {"image": scan_imgs,
273
+ "label":scan_lbs,
274
+ "scan_id": scan_id,
275
+ }
276
+
277
+ return sample
278
+
279
+
280
+ def get_3_slice_adjacent_image(self, image_t, index):
281
+ curr_dict = self.actual_dataset[index]
282
+ prev_image = np.zeros_like(image_t)
283
+
284
+ if index > 1 and not curr_dict["is_start"]:
285
+ prev_dict = self.actual_dataset[index - 1]
286
+ prev_image = prev_dict["img"]
287
+
288
+ next_image = np.zeros_like(image_t)
289
+ if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
290
+ next_dict = self.actual_dataset[index + 1]
291
+ next_image = next_dict["img"]
292
+
293
+ image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
294
+
295
+ return image_t
296
+
297
+
298
+ def __getitem_default__(self, index):
299
+ index = index % len(self.actual_dataset)
300
+ curr_dict = self.actual_dataset[index]
301
+ if self.is_train:
302
+ if len(self.exclude_lbs) > 0:
303
+ for _ex_cls in self.exclude_lbs:
304
+ if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen
305
+ return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
306
+
307
+ comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 )
308
+ if self.transforms is not None:
309
+ img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False)
310
+ else:
311
+ raise Exception("No transform function is provided")
312
+
313
+ else:
314
+ img = curr_dict['img']
315
+ lb = curr_dict['lb']
316
+
317
+
318
+ img = np.float32(img)
319
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
320
+ if self.use_3_slices:
321
+ img = self.get_3_slice_adjacent_image(img, index)
322
+
323
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
324
+ lb = torch.from_numpy( lb)
325
+
326
+ if self.tile_z_dim:
327
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
328
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
329
+
330
+ is_start = curr_dict["is_start"]
331
+ is_end = curr_dict["is_end"]
332
+ nframe = np.int32(curr_dict["nframe"])
333
+ scan_id = curr_dict["scan_id"]
334
+ z_id = curr_dict["z_id"]
335
+
336
+ sample = {"image": img,
337
+ "label":lb,
338
+ "is_start": is_start,
339
+ "is_end": is_end,
340
+ "nframe": nframe,
341
+ "scan_id": scan_id,
342
+ "z_id": z_id
343
+ }
344
+ # Add auxiliary attributes
345
+ if self.aux_attrib is not None:
346
+ for key_prefix in self.aux_attrib:
347
+ # Process the data sample, create new attributes and save them in a dictionary
348
+ aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
349
+ for key_suffix in aux_attrib_val:
350
+ # one function may create multiple attributes, so we need suffix to distinguish them
351
+ sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
352
+
353
+ return sample
354
+
355
+ def __len__(self):
356
+ """
357
+ copy-paste from basic naive dataset configuration
358
+ """
359
+ if self.get_item_mode == MODE_FULL_SCAN:
360
+ return len(self.scan_z_idx)
361
+
362
+ if self.fix_length != None:
363
+ assert self.fix_length >= len(self.actual_dataset)
364
+ return self.fix_length
365
+ else:
366
+ return len(self.actual_dataset)
367
+
368
+ def update_subclass_lookup(self):
369
+ """
370
+ Updating the class-slice indexing list
371
+ Args:
372
+ [internal] overall_slice_by_cls:
373
+ {
374
+ class1: {pid1: [slice1, slice2, ....],
375
+ pid2: [slice1, slice2]},
376
+ ...}
377
+ class2:
378
+ ...
379
+ }
380
+ out[internal]:
381
+ {
382
+ class1: [ idx1, idx2, ... ],
383
+ class2: [ idx1, idx2, ... ],
384
+ ...
385
+ }
386
+
387
+ """
388
+ # delete previous ones if any
389
+ assert self.overall_slice_by_cls is not None
390
+
391
+ if not hasattr(self, 'idx_by_class'):
392
+ self.idx_by_class = {}
393
+ # filter the new one given the actual list
394
+ for cls in self.label_name:
395
+ if cls not in self.idx_by_class.keys():
396
+ self.idx_by_class[cls] = []
397
+ else:
398
+ del self.idx_by_class[cls][:]
399
+ for cls, dict_by_pid in self.overall_slice_by_cls.items():
400
+ for pid, slice_list in dict_by_pid.items():
401
+ if pid not in self.pid_curr_load:
402
+ continue
403
+ self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ]
404
+ print("###### index-by-class table has been reloaded ######")
405
+
406
+ def getMaskMedImg(self, label, class_id, class_ids):
407
+ """
408
+ Generate FG/BG mask from the segmentation mask. Used when getting the support
409
+ """
410
+ # Dense Mask
411
+ fg_mask = torch.where(label == class_id,
412
+ torch.ones_like(label), torch.zeros_like(label))
413
+ bg_mask = torch.where(label != class_id,
414
+ torch.ones_like(label), torch.zeros_like(label))
415
+ for class_id in class_ids:
416
+ bg_mask[label == class_id] = 0
417
+
418
+ return {'fg_mask': fg_mask,
419
+ 'bg_mask': bg_mask}
420
+
421
+ def subsets(self, sub_args_lst=None):
422
+ """
423
+ Override base-class subset method
424
+ Create subsets by scan_ids
425
+
426
+ output: list [[<fid in each class>] <class1>, <class2> ]
427
+ """
428
+
429
+ if sub_args_lst is not None:
430
+ subsets = []
431
+ ii = 0
432
+ for cls_name, index_list in self.idx_by_class.items():
433
+ subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) )
434
+ ii += 1
435
+ else:
436
+ subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()]
437
+ return subsets
438
+
439
+ def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int):
440
+ """
441
+ getting (probably multi-shot) support set for evaluation
442
+ sample from 50% (1shot) or 20 35 50 65 80 (5shot)
443
+ Args:
444
+ curr_cls: current class to segment, starts from 1
445
+ class_idx: a list of all foreground class in nways, starts from 1
446
+ npart: how many chunks are used to split the support
447
+ scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan
448
+ being served as support, in self.pid_curr_load
449
+ """
450
+ assert npart % 2 == 1
451
+ assert curr_class != 0; assert 0 not in class_idx
452
+ # assert not self.is_train
453
+
454
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
455
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
456
+
457
+ if npart == 1:
458
+ pcts = [0.5]
459
+ else:
460
+ half_part = 1 / (npart * 2)
461
+ part_interval = (1.0 - 1.0 / npart) / (npart - 1)
462
+ pcts = [ half_part + part_interval * ii for ii in range(npart) ]
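+ # e.g. npart = 3 gives pcts = [1/6, 1/2, 5/6], i.e. support slices at roughly 17%, 50% and 83% of the labeled range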
463
+
464
+ # print(f'###### Parts percentage: {pcts} ######')
465
+
466
+ # norm_func = get_normalize_op(modality='MR', fids=None)
467
+ out_buffer = [] # [{scanid, img, lb}]
468
+ for _part in range(npart):
469
+ concat_buffer = [] # for each fold do a concat in image and mask in batch dimension
470
+ for scan_order in scan_idx:
471
+ _scan_id = self.pid_curr_load[ scan_order ]
472
+ print(f'Using scan {_scan_id} as support!')
473
+
474
+ # for _pc in pcts:
475
+ _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices
476
+ _zid = _zlist[int(pcts[_part] * len(_zlist))]
477
+ _glb_idx = self.scan_z_idx[_scan_id][_zid]
478
+
479
+ # almost copy-paste __getitem__ but no augmentation
480
+ curr_dict = self.actual_dataset[_glb_idx]
481
+ img = curr_dict['img']
482
+ lb = curr_dict['lb']
483
+
484
+ if self.use_3_slices:
485
+ prev_image = np.zeros_like(img)
486
+ if _glb_idx > 1 and not curr_dict["is_start"]:
487
+ prev_dict = self.actual_dataset[_glb_idx - 1]
488
+ prev_image = prev_dict["img"]
489
+
490
+ next_image = np.zeros_like(img)
491
+ if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
492
+ next_dict = self.actual_dataset[_glb_idx + 1]
493
+ next_image = next_dict["img"]
494
+
495
+ img = np.concatenate([prev_image, img, next_image], axis=-1)
496
+
497
+ img = np.float32(img)
498
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
499
+
500
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
501
+ lb = torch.from_numpy( lb )
502
+
503
+ if self.tile_z_dim:
504
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
505
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
506
+
507
+ is_start = curr_dict["is_start"]
508
+ is_end = curr_dict["is_end"]
509
+ nframe = np.int32(curr_dict["nframe"])
510
+ scan_id = curr_dict["scan_id"]
511
+ z_id = curr_dict["z_id"]
512
+
513
+ sample = {"image": img,
514
+ "label":lb,
515
+ "is_start": is_start,
516
+ "inst": None,
517
+ "scribble": None,
518
+ "is_end": is_end,
519
+ "nframe": nframe,
520
+ "scan_id": scan_id,
521
+ "z_id": z_id
522
+ }
523
+
524
+ concat_buffer.append(sample)
525
+ out_buffer.append({
526
+ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
527
+ "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
528
+
529
+ })
530
+
531
+ # do the concat, and add to output_buffer
532
+
533
+ # post-processing, including keeping the foreground and suppressing background.
534
+ support_images = []
535
+ support_mask = []
536
+ support_class = []
537
+ for itm in out_buffer:
538
+ support_images.append(itm["image"])
539
+ support_class.append(curr_class)
540
+ support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
541
+
542
+ return {'class_ids': [support_class],
543
+ 'support_images': [support_images], #
544
+ 'support_mask': [support_mask],
545
+ }
546
+
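+ # For reference, the dict returned above contains:
+ # 'class_ids': [[curr_class] * npart]
+ # 'support_images': [[Tensor of shape (n_shots, C, H, W), one per chunk]]
+ # 'support_mask': [[getMaskMedImg(label, curr_class, class_idx), one per chunk]]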
547
+ def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list):
548
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
549
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
550
+ scan_slices = self.scan_z_idx[self.potential_support_sid[0]]
551
+ scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
552
+
553
+ scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
554
+ # binarize the labels
555
+ scan_lbs[scan_lbs != curr_class] = 0
556
+ scan_lbs[scan_lbs == curr_class] = 1
557
+
558
+ scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0)
559
+ scan_lbs = torch.from_numpy(np.float32(scan_lbs))
560
+
561
+ if self.tile_z_dim:
562
+ scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
563
+ assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
564
+
565
+ # reshape to C, D, H, W
566
+ sample = {"scan": scan_imgs,
567
+ "labels":scan_lbs,
568
+ }
569
+
570
+ return sample
571
+
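+ # For reference: the returned 'scan' has shape (tile_z_dim, D, H, W) (or (1, D, H, W) when
+ # tile_z_dim is unset) and 'labels' has shape (D, H, W), binarized to {0, 1} for curr_class.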
572
+
573
+ def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False):
574
+ """
575
+ getting (probably multi-shot) support set for evaluation
576
+ sample support slices at evenly spaced depth percentiles (50% for npart=1; 10/30/50/70/90% for npart=5)
577
+ Args:
578
+ classes: a list of foreground class ids to include in the support mask, starts from 1
+ use_3_slices: if True, stack the previous and next slices as extra channels
580
+ npart: how many chunks are used to split the support
581
+ scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan
582
+ being served as support, in self.pid_curr_load
583
+ """
584
+ assert npart % 2 == 1
585
+ # assert curr_class != 0; assert 0 not in class_idx
586
+ # assert not self.is_train
587
+
588
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
589
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
590
+
591
+ if npart == 1:
592
+ pcts = [0.5]
593
+ else:
594
+ half_part = 1 / (npart * 2)
595
+ part_interval = (1.0 - 1.0 / npart) / (npart - 1)
596
+ pcts = [ half_part + part_interval * ii for ii in range(npart) ]
597
+
598
+ # print(f'###### Parts percentage: {pcts} ######')
599
+
600
+ out_buffer = [] # [{scanid, img, lb}]
601
+ for _part in range(npart):
602
+ concat_buffer = [] # for each fold do a concat in image and mask in batch dimension
603
+ for scan_order in scan_idx:
604
+ _scan_id = self.pid_curr_load[ scan_order ]
605
+ print(f'Using scan {_scan_id} as support!')
606
+
607
+ # for _pc in pcts:
608
+ zlist = []
609
+ for curr_class in classes:
610
+ zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices
611
+ # merge all the lists in zlist and keep only the unique elements
612
+ # _zlist = sorted(list(set([item for sublist in zlist for item in sublist])))
613
+ # take only the indices that appear in all of the sublist
614
+ _zlist = sorted(list(set.intersection(*map(set, zlist))))
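+ # e.g. zlist = [[10, 11, 12, 13], [12, 13, 14]] -> _zlist = [12, 13]:
+ # only slices that contain every requested class are kept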
615
+ _zid = _zlist[int(pcts[_part] * len(_zlist))]
616
+ _glb_idx = self.scan_z_idx[_scan_id][_zid]
617
+
618
+ # almost copy-paste __getitem__ but no augmentation
619
+ curr_dict = self.actual_dataset[_glb_idx]
620
+ img = curr_dict['img']
621
+ lb = curr_dict['lb']
622
+
623
+ if use_3_slices:
624
+ prev_image = np.zeros_like(img)
625
+ if _glb_idx > 1 and not curr_dict["is_start"]:
626
+ prev_dict = self.actual_dataset[_glb_idx - 1]
627
+ assert prev_dict["scan_id"] == curr_dict["scan_id"]
628
+ assert prev_dict["z_id"] == curr_dict["z_id"] - 1
629
+ prev_image = prev_dict["img"]
630
+
631
+ next_image = np.zeros_like(img)
632
+ if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
633
+ next_dict = self.actual_dataset[_glb_idx + 1]
634
+ assert next_dict["scan_id"] == curr_dict["scan_id"]
635
+ assert next_dict["z_id"] == curr_dict["z_id"] + 1
636
+ next_image = next_dict["img"]
637
+
638
+ img = np.concatenate([prev_image, img, next_image], axis=-1)
639
+
640
+ img = np.float32(img)
641
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
642
+ # zero all labels that are not in the classes arg
643
+ mask = np.zeros_like(lb)
644
+ for cls in classes:
645
+ mask[lb == cls] = 1
646
+ lb[~mask.astype(bool)] = 0  # np.bool was removed in recent NumPy; use the builtin bool
647
+
648
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
649
+ lb = torch.from_numpy( lb )
650
+
651
+ if self.tile_z_dim:
652
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
653
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
654
+
655
+ is_start = curr_dict["is_start"]
656
+ is_end = curr_dict["is_end"]
657
+ nframe = np.int32(curr_dict["nframe"])
658
+ scan_id = curr_dict["scan_id"]
659
+ z_id = curr_dict["z_id"]
660
+
661
+ sample = {"image": img,
662
+ "label":lb,
663
+ "is_start": is_start,
664
+ "inst": None,
665
+ "scribble": None,
666
+ "is_end": is_end,
667
+ "nframe": nframe,
668
+ "scan_id": scan_id,
669
+ "z_id": z_id
670
+ }
671
+
672
+ concat_buffer.append(sample)
673
+ out_buffer.append({
674
+ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
675
+ "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
676
+
677
+ })
678
+
679
+ # do the concat, and add to output_buffer
680
+
681
+ # post-processing, including keeping the foreground and suppressing background.
682
+ support_images = []
683
+ support_mask = []
684
+ support_class = []
685
+ for itm in out_buffer:
686
+ support_images.append(itm["image"])
687
+ support_class.append(curr_class)
688
+ # support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
689
+ support_mask.append(itm["label"])
690
+
691
+ return {'class_ids': [support_class],
692
+ 'support_images': [support_images], #
693
+ 'support_mask': [support_mask],
694
+ 'scan_id': scan_id
695
+ }
696
+
697
+ def get_nii_dataset(config, image_size, **kwargs):
698
+ print(f"Check config: {config}")
699
+ organ_mapping = {
700
+ "sabs":{
701
+ "rk": 2,
702
+ "lk": 3,
703
+ "liver": 6,
704
+ "spleen": 1
705
+ },
706
+ "chaost2":{
707
+ "liver": 1,
708
+ "rk": 2,
709
+ "lk": 3,
710
+ "spleen": 4
711
+ }}
712
+
713
+ transforms = None
714
+ data_name = config['dataset']
715
+ if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672':
716
+ baseset_name = 'SABS'
717
+ max_label = 13
718
+ modality="CT"
719
+ elif data_name == 'C0_Superpix':
720
+ raise NotImplementedError
721
+ baseset_name = 'C0'
722
+ max_label = 3
723
+ elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672':
724
+ baseset_name = 'CHAOST2'
725
+ max_label = 4
726
+ modality="MR"
727
+ elif 'lits' in data_name.lower():
728
+ baseset_name = 'LITS17'
729
+ max_label = 4
730
+ else:
731
+ raise ValueError(f'Dataset: {data_name} not found')
732
+
733
+ # norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics
734
+ # norm_func = None
735
+
736
+ test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]]
737
+ base_dir = config['path'][data_name]['data_dir']
738
+ testdataset = ManualAnnoDataset(which_dataset=baseset_name,
739
+ base_dir=base_dir,
740
+ idx_split = config['eval_fold'],
741
+ mode = 'val',
742
+ scan_per_load = 1,
743
+ transforms=transforms,
744
+ min_fg=1,
745
+ nsup = config["task"]["n_shots"],
746
+ fix_length=None,
747
+ image_size=image_size,
748
+ # extern_normalize_func=norm_func
749
+ **kwargs)
750
+
751
+ testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"])
752
+ testdataset.set_curr_cls(test_label)
753
+
754
+ traindataset = None # TODO make this the support set later
755
+
756
+ return traindataset, testdataset
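+ # Usage sketch (illustrative; the config keys mirror the accesses above, the path is hypothetical):
+ # cfg = {'dataset': 'SABS_Superpix', 'curr_cls': 'liver', 'eval_fold': 0,
+ #        'path': {'SABS_Superpix': {'data_dir': './data/SABS'}},
+ #        'task': {'n_shots': 1, 'npart': 3}}
+ # _, test_ds = get_nii_dataset(cfg, image_size=(256, 256))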
dataloaders/PolypDataset.py ADDED
@@ -0,0 +1,548 @@
1
+ """
2
+ Copied from https://github.com/talshaharabany/AutoSAM
3
+ """
4
+
5
+ import os
6
+ from PIL import Image
7
+ import torch.utils.data as data
8
+ import torchvision.transforms as transforms
9
+ import numpy as np
10
+ import random
11
+ import torch
12
+ from dataloaders.PolypTransforms import get_polyp_transform
13
+ import cv2
14
+ KVASIR = "Kvasir"
15
+ CLINIC_DB = "CVC-ClinicDB"
16
+ COLON_DB = "CVC-ColonDB"
17
+ ETIS_DB = "ETIS-LaribPolypDB"
18
+ CVC300 = "CVC-300"
19
+
20
+ DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
21
+ EXCLUDE_DS = (CVC300, )
22
+
23
+
24
+ def create_suppport_set_for_polyps(n_support=10):
25
+ """
26
+ create a text file contating n_support_images for each dataset
27
+ """
28
+ root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"
29
+ supp_images = []
30
+ supp_masks = []
31
+
32
+ image_dir = os.path.join(root_dir, "images")
33
+ mask_dir = os.path.join(root_dir, "masks")
34
+ # randomly sample n_support images and masks
35
+ image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(
36
+ image_dir) if f.endswith('.jpg') or f.endswith('.png')])
37
+ mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(
38
+ mask_dir) if f.endswith('.png')])
39
+
40
+ while len(supp_images) < n_support:
41
+ index = random.randint(0, len(image_paths) - 1)
42
+ # check that the index is not already in the support set
43
+ if image_paths[index] in supp_images:
44
+ continue
45
+ supp_images.append(image_paths[index])
46
+ supp_masks.append(mask_paths[index])
47
+
48
+ with open(os.path.join(root_dir, "support.txt"), 'w') as file:
49
+ for image_path, mask_path in zip(supp_images, supp_masks):
50
+ file.write(f"{image_path} {mask_path}\n")
51
+
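+ # The generated support.txt holds one "<image_path> <mask_path>" pair per line;
+ # this is the format consumed by PolypDataset.get_support_from_text_file below.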
52
+ def create_train_val_test_split_for_polyps():
53
+ root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"
54
+ # for each subdir in root_dir, create a split file
55
+ num_train_images_per_dataset = {
56
+ "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
57
+
58
+ num_test_images_per_dataset = {
59
+ "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}
60
+
61
+ for subdir in os.listdir(root_dir):
62
+ subdir_path = os.path.join(root_dir, subdir)
63
+ if os.path.isdir(subdir_path):
64
+ split_file = os.path.join(subdir_path, "split.txt")
65
+ image_dir = os.path.join(subdir_path, "images")
66
+ create_train_val_test_split(
67
+ image_dir, split_file, train_number=num_train_images_per_dataset[subdir], test_number=num_test_images_per_dataset[subdir])
68
+
69
+
70
+ def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
71
+ """
72
+ Create a train, val, test split file for a dataset
73
+ root: root directory of dataset
74
+ split_file: name of split file to create
75
+ train_ratio: ratio of train set
76
+ val_ratio: ratio of val set
77
+ test_ratio: ratio of test set
78
+ """
79
+ # Get all files in root directory
80
+ files = os.listdir(root)
81
+ # Filter out non-image files, remove suffix
82
+ files = [f.split('.')[0]
83
+ for f in files if f.endswith('.jpg') or f.endswith('.png')]
84
+ # Shuffle files
85
+ random.shuffle(files)
86
+
87
+ # Calculate number of files for each split
88
+ num_files = len(files)
89
+ num_train = train_number
90
+ num_test = test_number
91
+ num_val = num_files - num_train - num_test
92
+ print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
93
+ # Create splits
94
+ train = files[:num_train]
95
+ val = files[num_train:num_train + num_val]
96
+ test = files[num_train + num_val:]
97
+
98
+ # Write splits to file
99
+ with open(split_file, 'w') as file:
100
+ file.write("train\n")
101
+ for f in train:
102
+ file.write(f + "\n")
103
+ file.write("val\n")
104
+ for f in val:
105
+ file.write(f + "\n")
106
+ file.write("test\n")
107
+ for f in test:
108
+ file.write(f + "\n")
109
+
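+ # The resulting split.txt lists bare file stems (extensions stripped) under the headers
+ # "train", "val" and "test", which is the format parsed by get_image_gt_pairs_from_text_file below.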
110
+
111
+ class PolypDataset(data.Dataset):
112
+ """
113
+ dataloader for polyp segmentation tasks
114
+ """
115
+
116
+ def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
117
+ self.trainsize = trainsize
118
+ self.augmentations = augmentations
119
+ self.datasets = datasets
120
+ if isinstance(image_size, int):
121
+ image_size = (image_size, image_size)
122
+ self.image_size = image_size
123
+ if image_root is not None and gt_root is not None:
124
+ self.images = [
125
+ os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
126
+ self.gts = [
127
+ os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png')]
128
+ # also look in subdirectories
129
+ for subdir in os.listdir(image_root):
130
+ # if not dir, continue
131
+ if not os.path.isdir(os.path.join(image_root, subdir)):
132
+ continue
133
+ subdir_image_root = os.path.join(image_root, subdir)
134
+ subdir_gt_root = os.path.join(gt_root, subdir)
135
+ self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
136
+ subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
137
+ self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
138
+ subdir_gt_root) if f.endswith('.png')])
139
+
140
+ else:
141
+ self.images, self.gts = self.get_image_gt_pairs(
142
+ root, split="train" if train else "test", datasets=self.datasets)
143
+ self.images = sorted(self.images)
144
+ self.gts = sorted(self.gts)
145
+ if not 'VPS' in root:
146
+ self.filter_files_and_get_ds_mean_and_std()
147
+ if ds_mean is not None and ds_std is not None:
148
+ self.mean, self.std = ds_mean, ds_std
149
+ self.size = len(self.images)
150
+ self.train = train
151
+ self.sam_trans = sam_trans
152
+ if self.sam_trans is not None:
153
+ # sam trans takes care of norm
154
+ self.mean, self.std = 0 , 1
155
+
156
+ def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
157
+ """
158
+ for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file
159
+ dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
160
+ split: train, val, or test
161
+ """
162
+ image_paths = []
163
+ gt_paths = []
164
+ for folder in os.listdir(dir_root):
165
+ if folder not in datasets:
166
+ continue
167
+ split_file = os.path.join(dir_root, folder, "split.txt")
168
+ if os.path.isfile(split_file):
169
+ image_root = os.path.join(dir_root, folder, "images")
170
+ gt_root = os.path.join(dir_root, folder, "masks")
171
+ image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file(
172
+ image_root, gt_root, split_file, split=split)
173
+ image_paths.extend(image_paths_tmp)
174
+ gt_paths.extend(gt_paths_tmp)
175
+ else:
176
+ print(
177
+ f"No split.txt file found in {os.path.join(dir_root, folder)}")
178
+
179
+ return image_paths, gt_paths
180
+
181
+ def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
182
+ """
183
+ image_root: root directory of images
184
+ gt_root: root directory of ground truth
185
+ text_file: text file containing train, val, test split with the following format:
186
+ train
187
+ image1
188
+ image2
189
+ ...
190
+ val
191
+ image1
192
+ image2
193
+ ...
194
+ test
195
+ image1
196
+ image2
197
+ ...
198
+
199
+ split: train, val, or test
200
+ """
201
+ # Initialize a dictionary to hold file names for each split
202
+ splits = {"train": [], "val": [], "test": []}
203
+ current_split = None
204
+
205
+ # Read the text file and categorize file names under each split
206
+ with open(text_file, 'r') as file:
207
+ for line in file:
208
+ line = line.strip()
209
+ if line in splits:
210
+ current_split = line
211
+ elif line and current_split:
212
+ splits[current_split].append(line)
213
+
214
+ # Get the file names for the requested split
215
+ file_names = splits[split]
216
+
217
+ # Create image-ground truth pairs
218
+ image_paths = []
219
+ gt_paths = []
220
+ for name in file_names:
221
+ image_path = os.path.join(image_root, name + '.png')
222
+ gt_path = os.path.join(gt_root, name + '.png')
223
+ image_paths.append(image_path)
224
+ gt_paths.append(gt_path)
225
+
226
+ return image_paths, gt_paths
227
+
228
+ def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
229
+ support_images = []
230
+ support_labels = []
231
+ # get all images and masks
232
+ support_image_paths = sorted([os.path.join(support_image_dir, f) for f in os.listdir(
233
+ support_image_dir) if f.endswith('.jpg') or f.endswith('.png')])
234
+ support_mask_paths = sorted([os.path.join(support_mask_dir, f) for f in os.listdir(
235
+ support_mask_dir) if f.endswith('.png')])
236
+ # sample n_support images and masks
237
+ for i in range(n_support):
238
+ index = random.randint(0, len(support_image_paths) - 1)
239
+ support_img = self.cv2_loader(
240
+ support_image_paths[index], is_mask=False)
241
+ support_mask = self.cv2_loader(
242
+ support_mask_paths[index], is_mask=True)
243
+ support_images.append(support_img)
244
+ support_labels.append(support_mask)
245
+
246
+ if self.augmentations:
+ # apply each (random) augmentation once per pair so image and mask stay aligned
+ aug_pairs = [self.augmentations(img, mask)
+ for img, mask in zip(support_images, support_labels)]
+ support_images = [pair[0] for pair in aug_pairs]
+ support_labels = [pair[1] for pair in aug_pairs]
251
+
252
+ support_images = [(support_image - self.mean) / self.std if support_image.max() == 255 and support_image.min() == 0 else support_image for support_image in support_images]
253
+
254
+ if self.sam_trans is not None:
255
+ support_images = [self.sam_trans.preprocess(
256
+ img).squeeze(0) for img in support_images]
257
+ support_labels = [self.sam_trans.preprocess(
258
+ mask) for mask in support_labels]
259
+ else:
260
+ image_size = self.image_size
261
+ support_images = [torch.nn.functional.interpolate(img.unsqueeze(
262
+ 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images]
263
+ support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
264
+ 0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels]
265
+
266
+ return torch.stack(support_images), torch.stack(support_labels)
267
+
268
+ def get_support_from_text_file(self, text_file, n_support=1):
269
+ """
270
+ each row in the file has 2 paths divided by space, the first is the image path and the second is the mask path
271
+ """
272
+ support_images = []
273
+ support_labels = []
274
+ with open(text_file, 'r') as file:
275
+ for line in file:
276
+ image_path, mask_path = line.strip().split()
277
+ support_images.append(image_path)
278
+ support_labels.append(mask_path)
279
+
280
+ # indices = random.choices(range(len(support_images)), k=n_support)
281
+ if n_support > len(support_images):
282
+ raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")
283
+
284
+ n_support_images = support_images[:n_support]
285
+ n_support_labels = support_labels[:n_support]
286
+
287
+ return n_support_images, n_support_labels
288
+
289
+ def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
290
+ """
291
+ Get support set from specified directories, text file or from the dataset itself
292
+ """
293
+ if support_image_dir is not None and support_mask_dir:
294
+ return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)
295
+ elif text_file is not None:
296
+ support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support)
297
+ else:
298
+ # randomly sample n_support images and masks from the dataset
299
+ indices = random.choices(range(self.size), k=n_support)
300
+ # indices = list(range(n_support))
301
+ print(f"support indices:{indices}")
302
+ support_image_paths = [self.images[index] for index in indices]
303
+ support_gt_paths = [self.gts[index] for index in indices]
304
+
305
+ support_images = []
306
+ support_gts = []
307
+
308
+ for image_path, gt_path in zip(support_image_paths, support_gt_paths):
309
+ support_img = self.cv2_loader(image_path, is_mask=False)
310
+ support_mask = self.cv2_loader(gt_path, is_mask=True)
311
+ out = self.process_image_gt(support_img, support_mask)
312
+ support_images.append(out['image'].unsqueeze(0))
313
+ support_gts.append(out['label'].unsqueeze(0))
314
+ if len(support_images) >= n_support:
315
+ break
316
+ return support_images, support_gts, out['case']
317
+ # return torch.stack(support_images), torch.stack(support_gts), out['case']
318
+
319
+ def process_image_gt(self, image, gt, dataset=""):
320
+ """
321
+ image and gt are expected to be output from self.cv2_loader
322
+ """
323
+ original_size = tuple(image.shape[-2:])
324
+ if self.augmentations:
325
+ image, mask = self.augmentations(image, gt)
326
+
327
+ if self.sam_trans:
328
+ image, mask = self.sam_trans.apply_image_torch(
329
+ image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
330
+ elif image.max() <= 255 and image.min() >= 0:
331
+ image = (image - self.mean) / self.std
332
+ mask[mask > 0.5] = 1
333
+ mask[mask <= 0.5] = 0
334
+ # image_size = tuple(img.shape[-2:])
335
+
336
+ image_size = self.image_size
337
+ if self.sam_trans is None:
338
+ image = torch.nn.functional.interpolate(image.unsqueeze(
339
+ 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
340
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
341
+ 0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
342
+ # img = (img - img.min()) / (img.max() - img.min()) # TODO uncomment this if results get worse
343
+
344
+ return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image,
345
+ 'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask,
346
+ 'original_size': torch.Tensor(original_size),
347
+ 'image_size': torch.Tensor(image_size),
348
+ 'case': dataset} # case to be compatible with polyp video dataset
349
+
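+ # For reference: without sam_trans, 'image' is a float tensor resized to self.image_size (bilinear)
+ # and 'label' is a {0, 1} mask of the same size (nearest); with sam_trans, both are passed through
+ # sam_trans.apply_image_torch and preprocess instead.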
350
+ def get_dataset_name_from_path(self, path):
351
+ for dataset in self.datasets:
352
+ if dataset in path:
353
+ return dataset
354
+ return ""
355
+
356
+ def __getitem__(self, index):
357
+ image = self.cv2_loader(self.images[index], is_mask=False)
358
+ gt = self.cv2_loader(self.gts[index], is_mask=True)
359
+ dataset = self.get_dataset_name_from_path(self.images[index])
360
+ return self.process_image_gt(image, gt, dataset)
361
+
362
+ def filter_files_and_get_ds_mean_and_std(self):
363
+ assert len(self.images) == len(self.gts)
364
+ images = []
365
+ gts = []
366
+ ds_mean = 0
367
+ ds_std = 0
368
+ for img_path, gt_path in zip(self.images, self.gts):
369
+ if any([ex_ds in img_path for ex_ds in EXCLUDE_DS]):
370
+ continue
371
+ img = Image.open(img_path)
372
+ gt = Image.open(gt_path)
373
+ if img.size == gt.size:
374
+ images.append(img_path)
375
+ gts.append(gt_path)
376
+ ds_mean += np.array(img).mean()
377
+ ds_std += np.array(img).std()
378
+ self.images = images
379
+ self.gts = gts
380
+ self.mean = ds_mean / len(self.images)
381
+ self.std = ds_std / len(self.images)
382
+
383
+ def rgb_loader(self, path):
384
+ with open(path, 'rb') as f:
385
+ img = Image.open(f)
386
+ return img.convert('RGB')
387
+
388
+ def binary_loader(self, path):
389
+ # with open(path, 'rb') as f:
390
+ # img = Image.open(f)
391
+ # return img.convert('1')
392
+ img = cv2.imread(path, 0)
393
+ return img
394
+
395
+ def cv2_loader(self, path, is_mask):
396
+ if is_mask:
397
+ img = cv2.imread(path, 0)
398
+ img[img > 0] = 1
399
+ else:
400
+ img = cv2.cvtColor(cv2.imread(
401
+ path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
402
+ return img
403
+
404
+ def resize(self, img, gt):
405
+ assert img.size == gt.size
406
+ w, h = img.size
407
+ if h < self.trainsize or w < self.trainsize:
408
+ h = max(h, self.trainsize)
409
+ w = max(w, self.trainsize)
410
+ return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
411
+ else:
412
+ return img, gt
413
+
414
+ def __len__(self):
415
+ # return 32
416
+ return self.size
417
+
418
+
419
+ class SuperpixPolypDataset(PolypDataset):
420
+ def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
421
+ self.trainsize = trainsize
422
+ self.augmentations = augmentations
423
+ self.datasets = datasets
424
+ self.image_size = image_size
425
+ # print(self.augmentations)
426
+ if image_root is not None and gt_root is not None:
427
+ self.images = [
428
+ os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
429
+ self.gts = [
430
+ os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png') and 'superpix' in f]
431
+ # also look in subdirectories
432
+ for subdir in os.listdir(image_root):
433
+ # if not dir, continue
434
+ if not os.path.isdir(os.path.join(image_root, subdir)):
435
+ continue
436
+ subdir_image_root = os.path.join(image_root, subdir)
437
+ subdir_gt_root = os.path.join(gt_root, subdir)
438
+ self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
439
+ subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
440
+ self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
441
+ subdir_gt_root) if f.endswith('.png')])
442
+
443
+ else:
444
+ self.images, self.gts = self.get_image_gt_pairs(
445
+ root, split="train" if train else "test", datasets=self.datasets)
446
+ self.images = sorted(self.images)
447
+ self.gts = sorted(self.gts)
448
+ if not 'VPS' in root:
449
+ self.filter_files_and_get_ds_mean_and_std()
450
+ if ds_mean is not None and ds_std is not None:
451
+ self.mean, self.std = ds_mean, ds_std
452
+ self.size = len(self.images)
453
+ self.train = train
454
+ self.sam_trans = sam_trans
455
+ if self.sam_trans is not None:
456
+ # sam trans takes care of norm
457
+ self.mean, self.std = 0 , 1
458
+
459
+
460
+ def __getitem__(self, index):
461
+ image = self.cv2_loader(self.images[index], is_mask=False)
462
+ gt = self.cv2_loader(self.gts[index], is_mask=False)
463
+ gt = gt[:, :, 0]
464
+ fgpath = os.path.basename(self.gts[index]).split('.png')[0].split('superpix-MIDDLE_')
465
+ fgpath = os.path.join(os.path.dirname(self.gts[index]), 'fgmask_' + fgpath[1] + '.png')
466
+ fg = self.cv2_loader(fgpath, is_mask=True)
467
+ dataset = self.get_dataset_name_from_path(self.images[index])
468
+
469
+ # randomly choose a superpixel from the gt
470
+ gt[fg == 0] = 0  # suppress superpixels outside the foreground mask (boolean mask, not integer indexing)
471
+ sp_id = random.choice(np.unique(gt)[1:])
472
+ sp = (gt == sp_id).astype(np.uint8)
473
+
474
+
475
+ out = self.process_image_gt(image, gt, dataset)
476
+ support_image, support_sp, dataset = out["image"], out["label"], out["case"]
477
+
478
+ out = self.process_image_gt(image, sp, dataset)
479
+ query_image, query_sp, dataset = out["image"], out["label"], out["case"]
480
+
481
+ # TODO tile the masks to have 3 channels?
482
+
483
+ support_bg_mask = 1 - support_sp
484
+ support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}
485
+
486
+ batch = {"support_images" : [[support_image]],
487
+ "support_mask" : [[support_masks]],
488
+ "query_images" : [query_image],
489
+ "query_labels" : [query_sp],
490
+ "scan_id" : [dataset]
491
+ }
492
+
493
+ return batch
494
+
495
+
496
+ def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None):
497
+ transform_train, transform_test = get_polyp_transform()
498
+ image_root = './data/PolypDataset/TrainDataset/images/'
499
+ gt_root = './data/PolypDataset/TrainDataset/superpixels/'
500
+ ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
501
+ augmentations=transform_train,
502
+ sam_trans=sam_trans,
503
+ image_size=image_size)
504
+
505
+ return ds_train
506
+
507
+ def get_polyp_dataset(image_size, sam_trans=None):
508
+ transform_train, transform_test = get_polyp_transform()
509
+ image_root = './data/PolypDataset/TrainDataset/images/'
510
+ gt_root = './data/PolypDataset/TrainDataset/masks/'
511
+ ds_train = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
512
+ augmentations=transform_test, sam_trans=sam_trans, train=True, image_size=image_size)
513
+ image_root = './data/PolypDataset/TestDataset/test/images/'
514
+ gt_root = './data/PolypDataset/TestDataset/test/masks/'
515
+ ds_test = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, train=False,
516
+ augmentations=transform_test, sam_trans=sam_trans, image_size=image_size)
517
+ return ds_train, ds_test
518
+
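+ # Usage sketch (the relative paths above are assumed to exist locally):
+ # ds_train, ds_test = get_polyp_dataset(image_size=(1024, 1024))
+ # loader = torch.utils.data.DataLoader(ds_test, batch_size=1, shuffle=False)
+ # batch = next(iter(loader))  # keys: 'image', 'label', 'original_size', 'image_size', 'case'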
519
+
520
+ def get_tests_polyp_dataset(sam_trans):
521
+ transform_train, transform_test = get_polyp_transform()
522
+
523
+ image_root = './data/polyp/TestDataset/Kvasir/images/'
524
+ gt_root = './data/polyp/TestDataset/Kvasir/masks/'
525
+ ds_Kvasir = PolypDataset(
526
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
527
+
528
+ image_root = './data/polyp/TestDataset/CVC-ClinicDB/images/'
529
+ gt_root = './data/polyp/TestDataset/CVC-ClinicDB/masks/'
530
+ ds_ClinicDB = PolypDataset(
531
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
532
+
533
+ image_root = './data/polyp/TestDataset/CVC-ColonDB/images/'
534
+ gt_root = './data/polyp/TestDataset/CVC-ColonDB/masks/'
535
+ ds_ColonDB = PolypDataset(
536
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
537
+
538
+ image_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/images/'
539
+ gt_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/masks/'
540
+ ds_ETIS = PolypDataset(
541
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
542
+
543
+ return ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS
544
+
545
+
546
+ if __name__ == '__main__':
547
+ # create_train_val_test_split_for_polyps()
548
+ create_suppport_set_for_polyps()
dataloaders/PolypTransforms.py ADDED
@@ -0,0 +1,626 @@
1
+ from __future__ import division
2
+ import torch
3
+ import math
4
+ import sys
5
+ import random
6
+ from PIL import Image
7
+
8
+ try:
9
+ import accimage
10
+ except ImportError:
11
+ accimage = None
12
+ import numpy as np
13
+ import numbers
14
+ import types
15
+ import collections
16
+ import warnings
17
+
18
+ from torchvision.transforms import functional as F
19
+
20
+ if sys.version_info < (3, 3):
21
+ Sequence = collections.Sequence
22
+ Iterable = collections.Iterable
23
+ else:
24
+ Sequence = collections.abc.Sequence
25
+ Iterable = collections.abc.Iterable
26
+
27
+ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
28
+ "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
29
+ "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop",
30
+ "ColorJitter", "RandomRotation", "RandomAffine",
31
+ "RandomPerspective"]
32
+
33
+ _pil_interpolation_to_str = {
34
+ Image.NEAREST: 'PIL.Image.NEAREST',
35
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
36
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
37
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
38
+ Image.HAMMING: 'PIL.Image.HAMMING',
39
+ Image.BOX: 'PIL.Image.BOX',
40
+ }
41
+
42
+
43
+ class Compose(object):
44
+ def __init__(self, transforms):
45
+ self.transforms = transforms
46
+
47
+ def __call__(self, img, mask):
48
+ for t in self.transforms:
49
+ img, mask = t(img, mask)
50
+ return img, mask
51
+
52
+
53
+ class ToTensor(object):
54
+ def __call__(self, img, mask):
55
+ # return F.to_tensor(img), F.to_tensor(mask)
56
+ img = np.array(img)
57
+ img = torch.from_numpy(img).permute(2, 0, 1).float() # TODO add division by 255 to match torch.ToTensor()?
58
+ mask = torch.from_numpy(np.array(mask)).float()
59
+ return img, mask
60
+
61
+
62
+ class ToPILImage(object):
63
+ def __init__(self, mode=None):
64
+ self.mode = mode
65
+
66
+ def __call__(self, img, mask):
67
+ return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode)
68
+
69
+
70
+ class Normalize(object):
71
+ def __init__(self, mean, std, inplace=False):
72
+ self.mean = mean
73
+ self.std = std
74
+ self.inplace = inplace
75
+
76
+ def __call__(self, img, mask):
77
+ return F.normalize(img, self.mean, self.std, self.inplace), mask
78
+
79
+
80
+ class Resize(object):
81
+ def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
82
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
83
+ self.size = size
84
+ self.interpolation = interpolation
85
+ self.do_mask = do_mask
86
+
87
+ def __call__(self, img, mask):
88
+ if self.do_mask:
89
+ return F.resize(img, self.size, Image.BICUBIC), F.resize(mask, self.size, Image.BICUBIC)
90
+ else:
91
+ return F.resize(img, self.size, Image.BICUBIC), mask
92
+
93
+
94
+ class CenterCrop(object):
95
+ def __init__(self, size):
96
+ if isinstance(size, numbers.Number):
97
+ self.size = (int(size), int(size))
98
+ else:
99
+ self.size = size
100
+
101
+ def __call__(self, img, mask):
102
+ return F.center_crop(img, self.size), F.center_crop(mask, self.size)
103
+
104
+
105
+ class Pad(object):
106
+ def __init__(self, padding, fill=0, padding_mode='constant'):
107
+ assert isinstance(padding, (numbers.Number, tuple))
108
+ assert isinstance(fill, (numbers.Number, str, tuple))
109
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
110
+ if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
111
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
112
+ "{} element tuple".format(len(padding)))
113
+
114
+ self.padding = padding
115
+ self.fill = fill
116
+ self.padding_mode = padding_mode
117
+
118
+ def __call__(self, img, mask):
119
+ return F.pad(img, self.padding, self.fill, self.padding_mode), \
120
+ F.pad(mask, self.padding, self.fill, self.padding_mode)
121
+
122
+
123
+ class Lambda(object):
124
+ def __init__(self, lambd):
125
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
126
+ self.lambd = lambd
127
+
128
+ def __call__(self, img, mask):
129
+ return self.lambd(img), self.lambd(mask)
130
+
131
+
132
+ class Lambda_image(object):
133
+ def __init__(self, lambd):
134
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
135
+ self.lambd = lambd
136
+
137
+ def __call__(self, img, mask):
138
+ return self.lambd(img), mask
139
+
140
+
141
+ class RandomTransforms(object):
142
+ def __init__(self, transforms):
143
+ assert isinstance(transforms, (list, tuple))
144
+ self.transforms = transforms
145
+
146
+ def __call__(self, *args, **kwargs):
147
+ raise NotImplementedError()
148
+
149
+
150
+ class RandomApply(RandomTransforms):
151
+ def __init__(self, transforms, p=0.5):
152
+ super(RandomApply, self).__init__(transforms)
153
+ self.p = p
154
+
155
+ def __call__(self, img, mask):
156
+ if self.p < random.random():
157
+ return img, mask
158
+ for t in self.transforms:
159
+ img, mask = t(img, mask)
160
+ return img, mask
161
+
162
+
163
+ class RandomOrder(RandomTransforms):
164
+ def __call__(self, img, mask):
165
+ order = list(range(len(self.transforms)))
166
+ random.shuffle(order)
167
+ for i in order:
168
+ img, mask = self.transforms[i](img, mask)
169
+ return img, mask
170
+
171
+
172
+ class RandomChoice(RandomTransforms):
173
+ def __call__(self, img, mask):
174
+ t = random.choice(self.transforms)
175
+ return t(img, mask)
176
+
177
+
178
+ class RandomCrop(object):
179
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
180
+ if isinstance(size, numbers.Number):
181
+ self.size = (int(size), int(size))
182
+ else:
183
+ self.size = size
184
+ self.padding = padding
185
+ self.pad_if_needed = pad_if_needed
186
+ self.fill = fill
187
+ self.padding_mode = padding_mode
188
+
189
+ @staticmethod
190
+ def get_params(img, output_size):
191
+ w, h = img.size
192
+ th, tw = output_size
193
+ if w == tw and h == th:
194
+ return 0, 0, h, w
195
+
196
+ i = random.randint(0, h - th)
197
+ j = random.randint(0, w - tw)
198
+ return i, j, th, tw
199
+
200
+ def __call__(self, img, mask):
201
+ if self.padding is not None:
202
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
203
+
204
+ # pad the width if needed
205
+ if self.pad_if_needed and img.size[0] < self.size[1]:
206
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
207
+ # pad the height if needed
208
+ if self.pad_if_needed and img.size[1] < self.size[0]:
209
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
210
+
211
+ i, j, h, w = self.get_params(img, self.size)
212
+
213
+ return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, mask):
221
+ if random.random() < self.p:
222
+ return F.hflip(img), F.hflip(mask)
223
+ return img, mask
224
+
225
+
226
+ class RandomVerticalFlip(object):
227
+ def __init__(self, p=0.5):
228
+ self.p = p
229
+
230
+ def __call__(self, img, mask):
231
+ if random.random() < self.p:
232
+ return F.vflip(img), F.vflip(mask)
233
+ return img, mask
234
+
235
+
236
+ class RandomPerspective(object):
237
+ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
238
+ self.p = p
239
+ self.interpolation = interpolation
240
+ self.distortion_scale = distortion_scale
241
+
242
+ def __call__(self, img, mask):
243
+ if not F._is_pil_image(img):
244
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
245
+
246
+ if random.random() < self.p:
247
+ width, height = img.size
248
+ startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
249
+ return F.perspective(img, startpoints, endpoints, self.interpolation), \
250
+ F.perspective(mask, startpoints, endpoints, Image.NEAREST)
251
+ return img, mask
252
+
253
+ @staticmethod
254
+ def get_params(width, height, distortion_scale):
255
+ half_height = int(height / 2)
256
+ half_width = int(width / 2)
257
+ topleft = (random.randint(0, int(distortion_scale * half_width)),
258
+ random.randint(0, int(distortion_scale * half_height)))
259
+ topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
260
+ random.randint(0, int(distortion_scale * half_height)))
261
+ botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
262
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
263
+ botleft = (random.randint(0, int(distortion_scale * half_width)),
264
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
265
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
266
+ endpoints = [topleft, topright, botright, botleft]
267
+ return startpoints, endpoints
268
+
269
+
270
+ class RandomResizedCrop(object):
271
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
272
+ if isinstance(size, tuple):
273
+ self.size = size
274
+ else:
275
+ self.size = (size, size)
276
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
277
+ warnings.warn("range should be of kind (min, max)")
278
+
279
+ self.interpolation = interpolation
280
+ self.scale = scale
281
+ self.ratio = ratio
282
+
283
+ @staticmethod
284
+ def get_params(img, scale, ratio):
285
+ area = img.size[0] * img.size[1]
286
+
287
+ for attempt in range(10):
288
+ target_area = random.uniform(*scale) * area
289
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
290
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
291
+
292
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
293
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
294
+
295
+ if w <= img.size[0] and h <= img.size[1]:
296
+ i = random.randint(0, img.size[1] - h)
297
+ j = random.randint(0, img.size[0] - w)
298
+ return i, j, h, w
299
+
300
+ # Fallback to central crop
301
+ in_ratio = img.size[0] / img.size[1]
302
+ if (in_ratio < min(ratio)):
303
+ w = img.size[0]
304
+ h = w / min(ratio)
305
+ elif (in_ratio > max(ratio)):
306
+ h = img.size[1]
307
+ w = h * max(ratio)
308
+ else: # whole image
309
+ w = img.size[0]
310
+ h = img.size[1]
311
+ i = (img.size[1] - h) // 2
312
+ j = (img.size[0] - w) // 2
313
+ return i, j, h, w
314
+
315
+ def __call__(self, img, mask):
316
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
317
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
318
+ F.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST)
319
+
320
+
321
+ class FiveCrop(object):
322
+ def __init__(self, size):
323
+ self.size = size
324
+ if isinstance(size, numbers.Number):
325
+ self.size = (int(size), int(size))
326
+ else:
327
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
328
+ self.size = size
329
+
330
+ def __call__(self, img, mask):
331
+ return F.five_crop(img, self.size), F.five_crop(mask, self.size)
332
+
333
+
334
+ class TenCrop(object):
335
+ def __init__(self, size, vertical_flip=False):
336
+ self.size = size
337
+ if isinstance(size, numbers.Number):
338
+ self.size = (int(size), int(size))
339
+ else:
340
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
341
+ self.size = size
342
+ self.vertical_flip = vertical_flip
343
+
344
+ def __call__(self, img, mask):
345
+ return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
346
+
347
+
348
+ class ColorJitter(object):
349
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
350
+ self.brightness = self._check_input(brightness, 'brightness')
351
+ self.contrast = self._check_input(contrast, 'contrast')
352
+ self.saturation = self._check_input(saturation, 'saturation')
353
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
354
+ clip_first_on_zero=False)
355
+
356
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
357
+ if isinstance(value, numbers.Number):
358
+ if value < 0:
359
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
360
+ value = [center - value, center + value]
361
+ if clip_first_on_zero:
362
+ value[0] = max(value[0], 0)
363
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
364
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
365
+ raise ValueError("{} values should be between {}".format(name, bound))
366
+ else:
367
+ raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
368
+
369
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
370
+ # or (0., 0.) for hue, do nothing
371
+ if value[0] == value[1] == center:
372
+ value = None
373
+ return value
374
+
375
+ @staticmethod
376
+ def get_params(brightness, contrast, saturation, hue):
377
+ transforms = []
378
+
379
+ if brightness is not None:
380
+ brightness_factor = random.uniform(brightness[0], brightness[1])
381
+ transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
382
+
383
+ if contrast is not None:
384
+ contrast_factor = random.uniform(contrast[0], contrast[1])
385
+ transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
386
+
387
+ if saturation is not None:
388
+ saturation_factor = random.uniform(saturation[0], saturation[1])
389
+ transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
390
+
391
+ if hue is not None:
392
+ hue_factor = random.uniform(hue[0], hue[1])
393
+ transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
394
+
395
+ random.shuffle(transforms)
396
+ transform = Compose(transforms)
397
+
398
+ return transform
399
+
400
+ def __call__(self, img, mask):
401
+ transform = self.get_params(self.brightness, self.contrast,
402
+ self.saturation, self.hue)
403
+ return transform(img, mask)
404
+
405
+
406
+ class RandomRotation(object):
407
+ def __init__(self, degrees, resample=False, expand=False, center=None):
408
+ if isinstance(degrees, numbers.Number):
409
+ if degrees < 0:
410
+ raise ValueError("If degrees is a single number, it must be positive.")
411
+ self.degrees = (-degrees, degrees)
412
+ else:
413
+ if len(degrees) != 2:
414
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
415
+ self.degrees = degrees
416
+
417
+ self.resample = resample
418
+ self.expand = expand
419
+ self.center = center
420
+
421
+ @staticmethod
422
+ def get_params(degrees):
423
+ angle = random.uniform(degrees[0], degrees[1])
424
+
425
+ return angle
426
+
427
+ def __call__(self, img, mask):
428
+ angle = self.get_params(self.degrees)
429
+
430
+ return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \
431
+ F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
432
+
433
+
434
+ class RandomAffine(object):
435
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
436
+ if isinstance(degrees, numbers.Number):
437
+ if degrees < 0:
438
+ raise ValueError("If degrees is a single number, it must be positive.")
439
+ self.degrees = (-degrees, degrees)
440
+ else:
441
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
442
+ "degrees should be a list or tuple and it must be of length 2."
443
+ self.degrees = degrees
444
+
445
+ if translate is not None:
446
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
447
+ "translate should be a list or tuple and it must be of length 2."
448
+ for t in translate:
449
+ if not (0.0 <= t <= 1.0):
450
+ raise ValueError("translation values should be between 0 and 1")
451
+ self.translate = translate
452
+
453
+ if scale is not None:
454
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
455
+ "scale should be a list or tuple and it must be of length 2."
456
+ for s in scale:
457
+ if s <= 0:
458
+ raise ValueError("scale values should be positive")
459
+ self.scale = scale
460
+
461
+ if shear is not None:
462
+ if isinstance(shear, numbers.Number):
463
+ if shear < 0:
464
+ raise ValueError("If shear is a single number, it must be positive.")
465
+ self.shear = (-shear, shear)
466
+ else:
467
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
468
+ "shear should be a list or tuple and it must be of length 2."
469
+ self.shear = shear
470
+ else:
471
+ self.shear = shear
472
+
473
+ self.resample = resample
474
+ self.fillcolor = fillcolor
475
+
476
+ @staticmethod
477
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
478
+ angle = random.uniform(degrees[0], degrees[1])
479
+ if translate is not None:
480
+ max_dx = translate[0] * img_size[0]
481
+ max_dy = translate[1] * img_size[1]
482
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
483
+ np.round(random.uniform(-max_dy, max_dy)))
484
+ else:
485
+ translations = (0, 0)
486
+
487
+ if scale_ranges is not None:
488
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
489
+ else:
490
+ scale = 1.0
491
+
492
+ if shears is not None:
493
+ shear = random.uniform(shears[0], shears[1])
494
+ else:
495
+ shear = 0.0
496
+
497
+ return angle, translations, scale, shear
498
+
499
+ def __call__(self, img, mask):
500
+ ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
501
+ return F.affine(img, *ret, interpolation=Image.BILINEAR, fill=self.fillcolor), \
502
+ F.affine(mask, *ret, interpolation=Image.NEAREST, fill=self.fillcolor)
503
+
504
+
505
+
506
+ def get_cub_transform():
507
+ transform_train = Compose([
508
+ ToPILImage(),
509
+ Resize((256, 256)),
510
+ RandomHorizontalFlip(),
511
+ RandomAffine(22, scale=(0.75, 1.25)),
512
+ ToTensor(),
513
+ Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
514
+ ])
515
+ transform_test = Compose([
516
+ ToPILImage(),
517
+ Resize((256, 256)),
518
+ ToTensor(),
519
+ Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
520
+ ])
521
+ return transform_train, transform_test
522
+
523
+
524
+ def get_glas_transform():
525
+ transform_train = Compose([
526
+ ToPILImage(),
527
+ # Resize((256, 256)),
528
+ ColorJitter(brightness=0.2,
529
+ contrast=0.2,
530
+ saturation=0.2,
531
+ hue=0.1),
532
+ RandomHorizontalFlip(),
533
+ RandomAffine(5, scale=(0.75, 1.25)),
534
+ ToTensor(),
535
+ # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
536
+ ])
537
+ transform_test = Compose([
538
+ ToPILImage(),
539
+ # Resize((256, 256)),
540
+ ToTensor(),
541
+ # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
542
+ ])
543
+ return transform_train, transform_test
544
+
545
+ # def get_glas_transform():
546
+ # transform_train = Compose([
547
+ # ToPILImage(),
548
+ # Resize((256, 256)),
549
+ # ColorJitter(brightness=0.2,
550
+ # contrast=0.2,
551
+ # saturation=0.2,
552
+ # hue=0.1),
553
+ # RandomHorizontalFlip(),
554
+ # RandomAffine(5, scale=(0.75, 1.25)),
555
+ # ToTensor(),
556
+ # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
557
+ # ])
558
+ # transform_test = Compose([
559
+ # ToPILImage(),
560
+ # Resize((256, 256)),
561
+ # ToTensor(),
562
+ # Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
563
+ # ])
564
+ # return transform_train, transform_test
565
+
566
+
567
+ def get_monu_transform(args):
568
+ Idim = int(args['Idim'])
569
+ transform_train = Compose([
570
+ ToPILImage(),
571
+ # Resize((Idim, Idim)),
572
+ ColorJitter(brightness=0.4,
573
+ contrast=0.4,
574
+ saturation=0.4,
575
+ hue=0.1),
576
+ RandomHorizontalFlip(),
577
+ RandomAffine(int(args['rotate']), scale=(float(args['scale1']), float(args['scale2']))),
578
+ ToTensor(),
579
+ # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
580
+ ])
581
+ transform_test = Compose([
582
+ ToPILImage(),
583
+ # Resize((Idim, Idim)),
584
+ ToTensor(),
585
+ # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
586
+ ])
587
+ return transform_train, transform_test
588
+
589
+
590
+ def get_polyp_transform():
591
+ transform_train = Compose([
592
+ # Resize((352, 352)),
593
+ ToPILImage(),
594
+ ColorJitter(brightness=0.4,
595
+ contrast=0.4,
596
+ saturation=0.4,
597
+ hue=0.1),
598
+ RandomVerticalFlip(),
599
+ RandomHorizontalFlip(),
600
+ RandomAffine(90, scale=(0.75, 1.25)),
601
+ ToTensor(),
602
+ # Normalize([105.61, 63.69, 45.67],
603
+ # [83.08, 55.86, 42.59])
604
+ ])
605
+ transform_test = Compose([
606
+ # Resize((352, 352)),
607
+ ToPILImage(),
608
+ ToTensor(),
609
+ # Normalize([105.61, 63.69, 45.67],
610
+ # [83.08, 55.86, 42.59])
611
+ ])
612
+ return transform_train, transform_test
613
+
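+ # Usage sketch (hypothetical arrays): every transform in this module takes an (image, mask) pair, e.g.
+ # train_tf, test_tf = get_polyp_transform()
+ # img_t, mask_t = test_tf(img_hwc_uint8, mask_hw_uint8)  # both come back as float tensors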
614
+
615
+ def get_polyp_support_train_transform():
616
+ transform_train = Compose([
617
+ ColorJitter(brightness=0.4,
618
+ contrast=0.4,
619
+ saturation=0.4,
620
+ hue=0.1),
621
+ RandomVerticalFlip(),
622
+ RandomHorizontalFlip(),
623
+ RandomAffine(90, scale=(0.75, 1.25)),
624
+ ])
625
+
626
+ return transform_train
dataloaders/SimpleDataset.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+ """
6
+ Simple dataset: takes lists of images and masks, together with a transform function that
7
+ should receive both the image and the mask.
8
+ loops: how many times to iterate over the dataset per epoch
9
+ """
10
+
11
+ class SimpleDataset(torch.utils.data.Dataset):
12
+ def __init__(self, image_list, mask_list, transform=None, norm_func=None, loops=10, modality="", debug=False, image_size=None):
13
+ self.image_list = image_list
14
+ if image_size is not None:
15
+ if isinstance(image_size, int):  # a bare int means a square output size
16
+ image_size = (image_size, image_size)
17
+ self.image_size = image_size
18
+ else:
19
+ self.image_size = image_list[0].shape[-2:]
20
+ self.mask_list = mask_list
21
+ self.transform = transform
22
+ self.norm_func = norm_func
23
+ self.loops = loops
24
+ self.modality = modality
25
+ self.debug = debug
26
+
27
+ def __len__(self):
28
+ return len(self.image_list) * self.loops
29
+
30
+ def __getitem__(self, idx):
31
+ idx = idx % (len(self.image_list))
32
+ image = self.image_list[idx].numpy()
33
+ mask = self.mask_list[idx].to(dtype=torch.uint8).numpy()
34
+ if self.modality == "CT":
35
+ image = image.astype(np.uint8)
36
+ if self.transform:
37
+ image, mask = self.transform(image, mask)
38
+ else:
39
+ # mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
40
+ if self.transform:
41
+ image, mask = self.transform(image, mask)
42
+
43
+ if self.norm_func:
44
+ image = self.norm_func(image)
45
+
46
+ mask[mask != 0] = 1
47
+
48
+ if self.image_size != image.shape[-2:]:
49
+ image = torch.nn.functional.interpolate(torch.tensor(image).unsqueeze(0), self.image_size, mode='bilinear').squeeze(0)
50
+ mask = torch.nn.functional.interpolate(torch.tensor(mask).unsqueeze(0).unsqueeze(0), self.image_size, mode='nearest').squeeze(0).squeeze(0)
51
+
52
+ # plot image and mask
53
+ if self.debug:
54
+ fig = plt.figure()
55
+ plt.imshow((image[0]- image.min()) / (image.max() - image.min()))
56
+ plt.imshow(mask, alpha=0.5)
57
+ plt.savefig("debug/support_image_mask.png")
58
+ plt.close(fig)
59
+
60
+ image_size = torch.tensor(tuple(image.shape[-2:]))
61
+ return image, mask
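+ # Usage sketch (hypothetical tensors; the transform, if given, must accept (image, mask) pairs):
+ # ds = SimpleDataset([img_chw], [mask_hw], transform=None, loops=10, image_size=(256, 256))
+ # image, mask = ds[0]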
dataloaders/__init__.py ADDED
File without changes
dataloaders/augutils.py ADDED
@@ -0,0 +1,224 @@
1
+ '''
2
+ Utilities for augmentation. Partial credit to Dr. Jo Schlemper
3
+ '''
4
+ from os.path import join
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torchvision.transforms as deftfx
9
+ import dataloaders.image_transforms as myit
10
+ import copy
11
+ from util.consts import IMG_SIZE
12
+ import time
13
+ import functools
14
+
15
+
16
+ def get_sabs_aug(input_size, use_3d=False):
17
+ sabs_aug = {
18
+ # turn flipping off as medical data has fixed orientations
19
+ 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
20
+ 'affine': {
21
+ 'rotate': 5,
22
+ 'shift': (5, 5),
23
+ 'shear': 5,
24
+ 'scale': (0.9, 1.2),
25
+ },
26
+ 'elastic': {'alpha': 10, 'sigma': 5},
27
+ 'patch': input_size,
28
+ 'reduce_2d': True,
29
+ '3d': use_3d,
30
+ 'gamma_range': (0.5, 1.5)
31
+ }
32
+ return sabs_aug
33
+
34
+
35
+ def get_sabs_augv3(input_size):
36
+ sabs_augv3 = {
37
+ 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
38
+ 'affine': {
39
+ 'rotate': 30,
40
+ 'shift': (30, 30),
41
+ 'shear': 30,
42
+ 'scale': (0.8, 1.3),
43
+ },
44
+ 'elastic': {'alpha': 20, 'sigma': 5},
45
+ 'patch': input_size,
46
+ 'reduce_2d': True,
47
+ 'gamma_range': (0.2, 1.8)
48
+ }
49
+ return sabs_augv3
50
+
51
+
52
+ def get_aug(which_aug, input_size):
53
+ if which_aug == 'sabs_aug':
54
+ return get_sabs_aug(input_size)
55
+ elif which_aug == 'aug_v3':
56
+ return get_sabs_augv3(input_size)
57
+ else:
58
+ raise NotImplementedError
59
+
60
+ # augs = {
61
+ # 'sabs_aug': get_sabs_aug,
62
+ # 'aug_v3': get_sabs_augv3, # more aggressive
63
+ # }
64
+
65
+
66
+ def get_geometric_transformer(aug, order=3):
67
+ """order: interpolation degree. Select order=0 for augmenting segmentation """
68
+ affine = aug['aug'].get('affine', 0)
69
+ alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
70
+ sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
71
+ flip = aug['aug'].get(
72
+ 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
73
+
74
+ tfx = []
75
+ if 'flip' in aug['aug']:
76
+ tfx.append(myit.RandomFlip3D(**flip))
77
+
78
+ if 'affine' in aug['aug']:
79
+ tfx.append(myit.RandomAffine(affine.get('rotate'),
80
+ affine.get('shift'),
81
+ affine.get('shear'),
82
+ affine.get('scale'),
83
+ affine.get('scale_iso', True),
84
+ order=order))
85
+
86
+ if 'elastic' in aug['aug']:
87
+ tfx.append(myit.ElasticTransform(alpha, sigma))
88
+ input_transform = deftfx.Compose(tfx)
89
+ return input_transform
90
+
91
+
92
+ def get_geometric_transformer_3d(aug, order=3):
93
+ """order: interpolation degree. Select order=0 for augmenting segmentation """
94
+ affine = aug['aug'].get('affine', 0)
95
+ alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
96
+ sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
97
+ flip = aug['aug'].get(
98
+ 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
99
+
100
+ tfx = []
101
+ if 'flip' in aug['aug']:
102
+ tfx.append(myit.RandomFlip3D(**flip))
103
+
104
+ if 'affine' in aug['aug']:
105
+ tfx.append(myit.RandomAffine(affine.get('rotate'),
106
+ affine.get('shift'),
107
+ affine.get('shear'),
108
+ affine.get('scale'),
109
+ affine.get('scale_iso', True),
110
+ order=order,
111
+ use_3d=True))
112
+
113
+ if 'elastic' in aug['aug']:
114
+ tfx.append(myit.ElasticTransform(alpha, sigma))
115
+ input_transform = deftfx.Compose(tfx)
116
+ return input_transform
117
+
118
+
119
+ def gamma_transform(img, aug):
120
+ gamma_range = aug['aug']['gamma_range']
121
+ if isinstance(gamma_range, tuple):
122
+ gamma = np.random.rand() * \
123
+ (gamma_range[1] - gamma_range[0]) + gamma_range[0]
124
+ cmin = img.min()
125
+ irange = (img.max() - cmin + 1e-5)
126
+
127
+ img = img - cmin + 1e-5
128
+ img = irange * np.power(img * 1.0 / irange, gamma)
129
+ img = img + cmin
130
+
131
+ elif gamma_range == False:
132
+ pass
133
+ else:
134
+ raise ValueError(
135
+ "Cannot identify gamma transform range {}".format(gamma_range))
136
+ return img
137
+
138
+
139
+ def get_intensity_transformer(aug):
140
+ """some basic intensity transforms"""
141
+ return functools.partial(gamma_transform, aug=aug)
142
+
143
+
144
+ def transform_with_label(aug):
145
+ """
146
+ Doing image geometric transform
147
+ Proposed image to have the following configurations
148
+ [H x W x C + CL]
149
+ Where CL is the number of channels for the label. It is NOT in one-hot form
150
+ """
151
+
152
+ geometric_tfx = get_geometric_transformer(aug)
153
+ intensity_tfx = get_intensity_transformer(aug)
154
+
155
+ def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
156
+ """
157
+ Args
158
+ comp: a numpy array with shape [H x W x C + c_label]
159
+ c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1)
160
+ nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label
161
+
162
+ """
163
+ comp = copy.deepcopy(comp)
164
+ if (use_onehot is True) and (c_label != 1):
165
+ raise NotImplementedError(
166
+ "Only allow compact label, also the label can only be 2d")
167
+ assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label"
168
+
169
+ # geometric transform
170
+ _label = comp[..., c_img]
171
+ _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
172
+ # _h_label = np.float32(_label[..., None])
173
+ comp = np.concatenate([comp[..., :c_img], _h_label], -1)
174
+ comp = geometric_tfx(comp)
175
+ # round one_hot labels to 0 or 1
176
+ t_label_h = comp[..., c_img:]
177
+ t_label_h = np.rint(t_label_h)
178
+ assert t_label_h.max() <= 1
179
+ t_img = comp[..., 0: c_img]
180
+
181
+ # intensity transform
182
+ t_img = intensity_tfx(t_img)
183
+
184
+ if use_onehot is True:
185
+ t_label = t_label_h
186
+ else:
187
+ t_label = np.expand_dims(np.argmax(t_label_h, axis=-1), -1)
188
+ return t_img, t_label
189
+
190
+ return transform
191
+
192
+
193
+ def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
194
+ """
195
+ Args
196
+ scan: a numpy array with shape [D x H x W x C]
197
+ label: a numpy array with shape [D x H x W x 1]
198
+ """
199
+ assert len(scan.shape) == 4, "Input scan must be 4D"
200
+ if len(label.shape) == 3:
201
+ label = np.expand_dims(label, -1)
202
+
203
+ # geometric transform
204
+ comp = copy.deepcopy(np.concatenate(
205
+ [scan, label], -1)) # [D x H x W x C + 1]
206
+ _label = comp[..., -1]
207
+ _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
208
+ comp = np.concatenate([comp[..., :-1], _h_label], -1)
209
+ # change comp to be H x W x D x C + 1
210
+ comp = np.transpose(comp, (1, 2, 0, 3))
211
+ comp = geometric_tfx(comp)
212
+ t_label_h = comp[..., 1:]
213
+ t_label_h = np.rint(t_label_h)
214
+ assert t_label_h.max() <= 1
215
+ t_img = comp[..., 0:1]
216
+
217
+ # intensity transform
218
+ t_img = intensity_tfx(t_img)
219
+ return t_img, t_label_h
220
+
221
+
222
+ def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
223
+ return transform(scan, label, nclass, geometric_tfx, intensity_tfx)
224
+
dataloaders/common.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Dataset classes for common uses
3
+ Extended from vanilla PANet code by Wang et al.
4
+ """
5
+ import random
6
+ import torch
7
+
8
+ from torch.utils.data import Dataset
9
+
10
+ class BaseDataset(Dataset):
11
+ """
12
+ Base Dataset
13
+ Args:
14
+ base_dir:
15
+ dataset directory
16
+ """
17
+ def __init__(self, base_dir):
18
+ self._base_dir = base_dir
19
+ self.aux_attrib = {}
20
+ self.aux_attrib_args = {}
21
+ self.ids = [] # must be overloaded in subclass
22
+
23
+ def add_attrib(self, key, func, func_args):
24
+ """
25
+ Add attribute to the data sample dict
26
+
27
+ Args:
28
+ key:
29
+ key in the data sample dict for the new attribute
30
+ e.g. sample['click_map'], sample['depth_map']
31
+ func:
32
+ function to process a data sample and create an attribute (e.g. user clicks)
33
+ func_args:
34
+ extra arguments to pass, expected a dict
35
+ """
36
+ if key in self.aux_attrib:
37
+ raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key))
38
+ else:
39
+ self.set_attrib(key, func, func_args)
40
+
41
+ def set_attrib(self, key, func, func_args):
42
+ """
43
+ Set attribute in the data sample dict
44
+
45
+ Args:
46
+ key:
47
+ key in the data sample dict for the new attribute
48
+ e.g. sample['click_map'], sample['depth_map']
49
+ func:
50
+ function to process a data sample and create an attribute (e.g. user clicks)
51
+ func_args:
52
+ extra arguments to pass, expected a dict
53
+ """
54
+ self.aux_attrib[key] = func
55
+ self.aux_attrib_args[key] = func_args
56
+
57
+ def del_attrib(self, key):
58
+ """
59
+ Remove attribute in the data sample dict
60
+
61
+ Args:
62
+ key:
63
+ key in the data sample dict
64
+ """
65
+ self.aux_attrib.pop(key)
66
+ self.aux_attrib_args.pop(key)
67
+
68
+ def subsets(self, sub_ids, sub_args_lst=None):
69
+ """
70
+ Create subsets by ids
71
+
72
+ Args:
73
+ sub_ids:
74
+ a sequence of sequences, each sequence contains data ids for one subset
75
+ sub_args_lst:
76
+ a list of args for some subset-specific auxiliary attribute function
77
+ """
78
+
79
+ indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids]
80
+ if sub_args_lst is not None:
81
+ subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args)
82
+ for index, args in zip(indices, sub_args_lst)]
83
+ else:
84
+ subsets = [Subset(dataset=self, indices=index) for index in indices]
85
+ return subsets
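+
+ # Usage sketch (assumption: `ds` is a BaseDataset subclass whose `ids` list has been filled in):
+ # train_subset, val_subset = ds.subsets([ds.ids[:10], ds.ids[10:]])
+ # sample = train_subset[0]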
86
+
87
+ def __len__(self):
88
+ pass
89
+
90
+ def __getitem__(self, idx):
91
+ pass
92
+
93
+
94
+ class ReloadPairedDataset(Dataset):
95
+ """
96
+ Make pairs of data from dataset
97
+ Eable only loading part of the entire data in each epoach and then reload to the next part
98
+ Args:
99
+ datasets:
100
+ source datasets, expect a list of Dataset.
101
+ Each dataset indices a certain class. It contains a list of all z-indices of this class for each scan
102
+ n_elements:
103
+ number of elements in a pair
104
+ curr_max_iters:
105
+ number of pairs in an epoch
106
+ pair_based_transforms:
107
+ some transformation performed on a pair basis, expect a list of functions,
108
+ each function takes a pair sample and return a transformed one.
109
+ """
110
+ def __init__(self, datasets, n_elements, curr_max_iters,
111
+ pair_based_transforms=None):
112
+ super().__init__()
113
+ self.datasets = datasets
114
+ self.n_datasets = len(self.datasets)
115
+ self.n_data = [len(dataset) for dataset in self.datasets]
116
+ self.n_elements = n_elements
117
+ self.curr_max_iters = curr_max_iters
118
+ self.pair_based_transforms = pair_based_transforms
119
+ self.update_index()
120
+
121
+ def update_index(self):
122
+ """
123
+ update the order of batches for the next episode
124
+ """
125
+
126
+ # update number of elements for each subset
127
+ if hasattr(self, 'indices'):
128
+ n_data_old = self.n_data # DEBUG
129
+ self.n_data = [len(dataset) for dataset in self.datasets]
130
+
131
+ if isinstance(self.n_elements, list):
132
+ self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use
133
+ for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use
134
+ for i_iter in range(self.curr_max_iters)] # sample <self.curr_max_iters> iterations
135
+
136
+ elif self.n_elements > self.n_datasets:
137
+ raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets")
138
+ else:
139
+ self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx]))
140
+ for dataset_idx in random.sample(range(self.n_datasets),
141
+ k=n_elements)]
142
+ for i in range(curr_max_iters)]
143
+
144
+ def __len__(self):
145
+ return self.curr_max_iters
146
+
147
+ def __getitem__(self, idx):
148
+ sample = [self.datasets[dataset_idx][data_idx]
149
+ for dataset_idx, data_idx in self.indices[idx]]
150
+ if self.pair_based_transforms is not None:
151
+ for transform, args in self.pair_based_transforms:
152
+ sample = transform(sample, **args)
153
+ return sample
154
+
155
+ class Subset(Dataset):
156
+ """
157
+ Subset of a dataset at specified indices. Used for seperating a dataset by class in our context
158
+
159
+ Args:
160
+ dataset:
161
+ The whole Dataset
162
+ indices:
163
+ Indices of samples of the current class in the entire dataset
164
+ sub_attrib_args:
165
+ Subset-specific arguments for attribute functions, expected a dict
166
+ """
167
+ def __init__(self, dataset, indices, sub_attrib_args=None):
168
+ self.dataset = dataset
169
+ self.indices = indices
170
+ self.sub_attrib_args = sub_attrib_args
171
+
172
+ def __getitem__(self, idx):
173
+ if self.sub_attrib_args is not None:
174
+ for key in self.sub_attrib_args:
175
+ # Make sure the dataset already has the corresponding attributes
176
+ # Here we only make the arguments subset dependent
177
+ # (i.e. pass different arguments for each subset)
178
+ self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key])
179
+ return self.dataset[self.indices[idx]]
180
+
181
+ def __len__(self):
182
+ return len(self.indices)
183
+
184
+ class ValidationDataset(Dataset):
185
+ """
186
+ Dataset for validation
187
+
188
+ Args:
189
+ dataset:
190
+ source dataset with a __getitem__ method
191
+ test_classes:
192
+ test classes
193
+ npart: int. number of parts, used for evaluation when assigning support images
194
+
195
+ """
196
+ def __init__(self, dataset, test_classes: list, npart: int):
197
+ super().__init__()
198
+ self.dataset = dataset
199
+ self.__curr_cls = None
200
+ self.test_classes = test_classes
201
+ self.dataset.aux_attrib = None
202
+ self.npart = npart
203
+
204
+ def set_curr_cls(self, curr_cls):
205
+ assert curr_cls in self.test_classes
206
+ self.__curr_cls = curr_cls
207
+
208
+ def get_curr_cls(self):
209
+ return self.__curr_cls
210
+
211
+ def read_dataset(self):
212
+ """
213
+ override original read_dataset to allow reading with z_margin
214
+ """
215
+ raise NotImplementedError
216
+
217
+ def __len__(self):
218
+ return len(self.dataset)
219
+
220
+ def label_strip(self, label):
221
+ """
222
+ mask unrelated labels out
223
+ """
224
+ out = torch.where(label == self.__curr_cls,
225
+ torch.ones_like(label), torch.zeros_like(label))
226
+ return out
227
+
228
+ def __getitem__(self, idx):
229
+ if self.__curr_cls is None:
230
+ raise Exception("Please initialize current class first")
231
+
232
+ sample = self.dataset[idx]
233
+ sample["label"] = self.label_strip( sample["label"] )
234
+ sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()
235
+
236
+ labelname = self.dataset.all_label_names[self.__curr_cls]
237
+ z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
238
+ z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
239
+ sample["z_min"], sample["z_max"] = z_min, z_max
240
+ try:
241
+ part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
242
+ except:
243
+ part_assign = 0
244
+ # print("###### DATASET: support only has one valid slice ######")
245
+ if part_assign < 0:
246
+ part_assign = 0
247
+ elif part_assign >= self.npart:
248
+ part_assign = self.npart - 1
249
+ sample["part_assign"] = part_assign
250
+ sample["case"] = sample["scan_id"]
251
+
252
+ return sample
253
+
254
+ def get_support_set(self, config, n_support=3):
255
+ support_batched = self.dataset.get_support(curr_class=self.__curr_cls, class_idx= [self.__curr_cls], scan_idx=config["support_idx"], npart=config["task"]["npart"])
256
+
257
+ support_images = [img for way in support_batched["support_images"] for img in way]
258
+ support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"] for fgmask in way]
259
+ support_scan_id = self.dataset.potential_support_sid
260
+ return {"support_images": support_images, "support_labels": support_labels, "support_scan_id": support_scan_id}
261
+
262
+
263
+
dataloaders/dataset_utils.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ Utils for datasets
3
+ """
4
+ import functools
5
+ import numpy as np
6
+
7
+ import os
8
+ import sys
9
+ import nibabel as nib
10
11
+ import pdb
12
+ import SimpleITK as sitk
13
+
14
+ DATASET_INFO = {
15
+ "CHAOST2": {
16
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
17
+ 'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"],
18
+ '_SEP': [0, 4, 8, 12, 16, 20],
19
+ 'MODALITY': 'MR',
20
+ 'LABEL_GROUP': {
21
+ 'pa_all': set(range(1, 5)),
22
+ 0: set([1, 4]), # upper_abdomen, leaving kidneys as testing classes
23
+ 1: set([2, 3]), # lower_abdomen
24
+ },
25
+ },
26
+
27
+ "SABS": {
28
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
29
+
30
+ 'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\
31
+ "PS_VEIN", "PANCREAS", "AG_R", "AG_L"],
32
+ '_SEP': [0, 6, 12, 18, 24, 30],
33
+ 'MODALITY': 'CT',
34
+ 'LABEL_GROUP':{
35
+ 'pa_all': set( [1,2,3,6] ),
36
+ 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneys are testing
37
+ 1: set( [2,3] ), # lower_abdomen
38
+ }
39
+ },
40
+ "LITS17": {
41
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
42
+
43
+ 'REAL_LABEL_NAME': ["BGD", "LIVER", "TUMOR"],
44
+ '_SEP': [0, 26, 52, 78, 104],
45
+ 'MODALITY': 'CT',
46
+ 'LABEL_GROUP':{
47
+ 'pa_all': set( [1 , 2] ),
48
+ 0: set([1 ] ), # liver
49
+ 1: set( [ 2] ), # tumor
50
+ 2: set([1,2]) # liver + tumor
51
+ }
52
+
53
+ }
54
+
55
+ }
56
+
57
+ def read_nii_bysitk(input_fid, peel_info = False):
58
+ """ read nii to numpy through simpleitk
59
+
60
+ peel_info: if True, also return direction, origin, spacing and array size
61
+ """
62
+ img_obj = sitk.ReadImage(input_fid)
63
+ img_np = sitk.GetArrayFromImage(img_obj)
64
+ if peel_info:
65
+ info_obj = {
66
+ "spacing": img_obj.GetSpacing(),
67
+ "origin": img_obj.GetOrigin(),
68
+ "direction": img_obj.GetDirection(),
69
+ "array_size": img_np.shape
70
+ }
71
+ return img_np, info_obj
72
+ else:
73
+ return img_np
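+
+ # Usage sketch (assumption: "./case_01.nii.gz" is any readable NIfTI file):
+ # img_np = read_nii_bysitk("./case_01.nii.gz")
+ # img_np, info = read_nii_bysitk("./case_01.nii.gz", peel_info=True)
+ # print(info["spacing"], info["array_size"])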
74
+
75
+
76
+ def get_CT_statistics(scan_fids):
77
+ """
78
+ As CT are quantitative, get mean and std for CT images for image normalizing
79
+ Since we may not be able to load all images at once, the statistics calculation is kept separate from the actual data loading
80
+ """
81
+ total_val = 0
82
+ n_pix = 0
83
+ for fid in scan_fids:
84
+ in_img = read_nii_bysitk(fid)
85
+ total_val += in_img.sum()
86
+ n_pix += np.prod(in_img.shape)
87
+ del in_img
88
+ meanval = total_val / n_pix
89
+
90
+ total_var = 0
91
+ for fid in scan_fids:
92
+ in_img = read_nii_bysitk(fid)
93
+ total_var += np.sum((in_img - meanval) ** 2 )
94
+ del in_img
95
+ var_all = total_var / n_pix
96
+
97
+ global_std = var_all ** 0.5
98
+
99
+ return meanval, global_std
100
+
101
+ def MR_normalize(x_in):
102
+ return (x_in - x_in.mean()) / x_in.std()
103
+
104
+ def CT_normalize(x_in, ct_mean, ct_std):
105
+ """
106
+ Normalizing CT images, based on global statistics
107
+ """
108
+ return (x_in - ct_mean) / ct_std
109
+
110
+ def get_normalize_op(modality, fids, ct_mean=None, ct_std=None):
111
+ """
112
+ Return a normalization function for the given modality
113
+ Args:
114
+ modality: CT or MR
115
+ fids: fids for the fold
116
+ """
117
+ if modality == 'MR':
118
+ return MR_normalize
119
+
120
+ elif modality == 'CT':
121
+ if ct_mean is None or ct_std is None:
122
+ ct_mean, ct_std = get_CT_statistics(fids)
123
+ # debug
124
+ print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean} STD {ct_std} ######')
125
+
126
+ return functools.partial(CT_normalize, ct_mean=ct_mean, ct_std=ct_std)
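+
+ # Usage sketch (assumption: `ct_fids` is a list of CT NIfTI paths for the current fold):
+ # norm_mr = get_normalize_op('MR', fids=None)     # per-image z-score, fids unused
+ # norm_ct = get_normalize_op('CT', fids=ct_fids)  # global mean/std computed over the fold
+ # img_norm = norm_ct(img_np)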
127
+
128
+
dataloaders/dev_customized_med.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ Customized dataset. Extended from vanilla PANet script by Wang et al.
3
+ """
4
+
5
+ import os
6
+ import random
7
+ import torch
8
+ import numpy as np
9
+
10
+ from dataloaders.common import ReloadPairedDataset, ValidationDataset
11
+ from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
12
+
13
+ def attrib_basic(_sample, class_id):
14
+ """
15
+ Add basic attribute
16
+ Args:
17
+ _sample: data sample
18
+ class_id: class label asscociated with the data
19
+ (sometimes indicting from which subset the data are drawn)
20
+ """
21
+ return {'class_id': class_id}
22
+
23
+ def getMaskOnly(label, class_id, class_ids):
24
+ """
25
+ Generate FG/BG mask from the segmentation mask
26
+
27
+ Args:
28
+ label:
29
+ semantic mask
30
+ scribble:
31
+ scribble mask
32
+ class_id:
33
+ semantic class of interest
34
+ class_ids:
35
+ all class id in this episode
36
+ """
37
+ # Dense Mask
38
+ fg_mask = torch.where(label == class_id,
39
+ torch.ones_like(label), torch.zeros_like(label))
40
+ bg_mask = torch.where(label != class_id,
41
+ torch.ones_like(label), torch.zeros_like(label))
42
+ for class_id in class_ids:
43
+ bg_mask[label == class_id] = 0
44
+
45
+ return {'fg_mask': fg_mask,
46
+ 'bg_mask': bg_mask}
47
+
48
+ def getMasks(*args, **kwargs):
49
+ raise NotImplementedError
50
+
51
+ def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
52
+ """
53
+ Postprocess paired sample for fewshot settings
54
+ For now only 1-way is tested but we leave multi-way possible (inherited from original PANet)
55
+
56
+ Args:
57
+ paired_sample:
58
+ data sample from a PairedDataset
59
+ n_ways:
60
+ n-way few-shot learning
61
+ n_shots:
62
+ n-shot few-shot learning
63
+ cnt_query:
64
+ number of query images for each class in the support set
65
+ coco:
66
+ MS COCO dataset. This is from the original PANet dataset but lets keep it for further extension
67
+ mask_only:
68
+ only give masks and no scribbles/ instances. Suitable for medical images (for now)
69
+ """
70
+ if not mask_only:
71
+ raise NotImplementedError
72
+ ###### Compose the support and query image list ######
73
+ cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # separation of supports and queries
74
+
75
+ # support class ids
76
+ class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query)
77
+
78
+ # support images
79
+ support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
80
+ for i in range(n_ways)] # fetch support images for each class
81
+
82
+ # support image labels
83
+ if coco:
84
+ support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
85
+ for j in range(n_shots)] for i in range(n_ways)]
86
+ else:
87
+ support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
88
+ for i in range(n_ways)]
89
+
90
+ if not mask_only:
91
+ support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
92
+ for i in range(n_ways)]
93
+ support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
94
+ for i in range(n_ways)]
95
+ else:
96
+ support_insts = []
97
+
98
+ # query images, masks and class indices
99
+ query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
100
+ for j in range(cnt_query[i])]
101
+ if coco:
102
+ query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
103
+ for i in range(n_ways) for j in range(cnt_query[i])]
104
+ else:
105
+ query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
106
+ for j in range(cnt_query[i])]
107
+ query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
108
+ for x in set(np.unique(query_label)) & set(class_ids)])
109
+ for query_label in query_labels]
110
+
111
+ ###### Generate support image masks ######
112
+ if not mask_only:
113
+ support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
114
+ class_ids[way], class_ids)
115
+ for shot in range(n_shots)] for way in range(n_ways)]
116
+ else:
117
+ support_mask = [[getMaskOnly(support_labels[way][shot],
118
+ class_ids[way], class_ids)
119
+ for shot in range(n_shots)] for way in range(n_ways)]
120
+
121
+ ###### Generate query label (class indices in one episode, i.e. the ground truth)######
122
+ query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
123
+ for i, query_label_tmp in enumerate(query_labels_tmp):
124
+ query_label_tmp[query_labels[i] == 255] = 255
125
+ for j in range(n_ways):
126
+ query_label_tmp[query_labels[i] == class_ids[j]] = j + 1
127
+
128
+ ###### Generate query mask for each semantic class (including BG) ######
129
+ # BG class
130
+ query_masks = [[torch.where(query_label == 0,
131
+ torch.ones_like(query_label),
132
+ torch.zeros_like(query_label))[None, ...],]
133
+ for query_label in query_labels]
134
+ # Other classes in query image
135
+ for i, query_label in enumerate(query_labels):
136
+ for idx in query_cls_idx[i][1:]:
137
+ mask = torch.where(query_label == class_ids[idx - 1],
138
+ torch.ones_like(query_label),
139
+ torch.zeros_like(query_label))[None, ...]
140
+ query_masks[i].append(mask)
141
+
142
+
143
+ return {'class_ids': class_ids,
144
+ 'support_images': support_images,
145
+ 'support_mask': support_mask,
146
+ 'support_inst': support_insts, # leave these interfaces
147
+ 'support_scribbles': support_scribbles,
148
+
149
+ 'query_images': query_images,
150
+ 'query_labels': query_labels_tmp,
151
+ 'query_masks': query_masks,
152
+ 'query_cls_idx': query_cls_idx,
153
+ }
154
+
155
+
156
+ def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
157
+ transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs):
158
+ """
159
+ Dataset wrapper
160
+ Args:
161
+ dataset_name:
162
+ indicates what dataset to use
163
+ base_dir:
164
+ dataset directory
165
+ mode:
166
+ which mode to use
167
+ choose from ('train', 'val', 'trainval', 'trainaug')
168
+ idx_split:
169
+ index of split
170
+ scan_per_load:
171
+ number of scans to load into memory as the dataset is large
172
+ use that together with reload_buffer
173
+ transforms:
174
+ transformations to be performed on images/masks
175
+ act_labels:
176
+ active labels involved in training process. Should be a subset of all labels
177
+ n_ways:
178
+ n-way few-shot learning, should be no more than # of object class labels
179
+ n_shots:
180
+ n-shot few-shot learning
181
+ max_iters_per_load:
182
+ number of pairs per load (epoch size)
183
+ n_queries:
184
+ number of query images
185
+ fix_parent_len:
186
+ fixed length of the parent dataset
187
+ """
188
+ med_set = ManualAnnoDataset
189
+
190
+
191
+ mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\
192
+ scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\
193
+ exclude_list = exclude_list, **kwargs)
194
+
195
+ mydataset.add_attrib('basic', attrib_basic, {})
196
+
197
+ # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside
198
+ subsets = mydataset.subsets([{'basic': {'class_id': ii}}
199
+ for ii, _ in enumerate(mydataset.label_name)])
200
+
201
+ # Choose the classes of queries
202
+ cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)
203
+ # Number of queries for each way
204
+ # Set the number of images for each class
205
+ n_elements = [n_shots + x for x in cnt_query] # <n_shots> supports + <cnt_query>[i] queries
206
+ # Create paired dataset. We do not include background.
207
+ paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load,
208
+ pair_based_transforms=[
209
+ (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
210
+ 'cnt_query': cnt_query, 'mask_only': True})])
211
+ return paired_data, mydataset
212
+
213
+ def update_loader_dset(loader, parent_set):
214
+ """
215
+ Update data loader and the parent dataset behind
216
+ Args:
217
+ loader: actual dataloader
218
+ parent_set: parent dataset which actually stores the data
219
+ """
220
+ parent_set.reload_buffer()
221
+ loader.dataset.update_index()
222
+ print(f'###### Loader and dataset have been updated ######' )
223
+
224
+ def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs):
225
+ """
226
+ validation set for med images
227
+ Args:
228
+ dataset_name:
229
+ indicates what dataset to use
230
+ base_dir:
231
+ SABS dataset directory
232
+ mode: (original split)
233
+ which split to use
234
+ choose from ('train', 'val', 'trainval', 'trainaug')
235
+ idx_split:
236
+ index of split
237
+ scan_per_batch:
238
+ number of scans to load into memory as the dataset is large
239
+ use that together with reload_buffer
240
+ act_labels:
241
+ actual labels involved in training process. Should be a subset of all labels
242
+ npart: number of chunks for splitting a 3d volume
243
+ nsup: number of support scans, equivalent to nshot
244
+ """
245
+ mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode, scan_per_load = scan_per_load, transforms=transforms, min_fg = 1, fix_length = fix_length, nsup = nsup, **kwargs)
246
+ mydataset.add_attrib('basic', attrib_basic, {})
247
+
248
+ valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart)
249
+
250
+ return valset, mydataset
dataloaders/image_transforms.py ADDED
@@ -0,0 +1,362 @@
1
+ """
2
+ Image transforms functions for data augmentation
3
+ Credit to Dr. Jo Schlemper
4
+ """
5
+
6
+ from collections.abc import Sequence
+ import cv2
+ import numpy as np
+ import scipy
+ from scipy.ndimage.filters import gaussian_filter
+ from scipy.ndimage import map_coordinates
+ from numpy.lib.stride_tricks import as_strided
+ from multiprocessing import Pool
+ import albumentations as A
+ import time
20
+
21
+ ###### UTILITIES ######
22
+ def random_num_generator(config, random_state=np.random):
23
+ if config[0] == 'uniform':
24
+ ret = random_state.uniform(config[1], config[2], 1)[0]
25
+ elif config[0] == 'lognormal':
26
+ ret = random_state.lognormal(config[1], config[2], 1)[0]
27
+ else:
28
+ #print(config)
29
+ raise Exception('unsupported format')
30
+ return ret
31
+
32
+ def get_translation_matrix(translation):
33
+ """ translation: [tx, ty] """
34
+ tx, ty = translation
35
+ translation_matrix = np.array([[1, 0, tx],
36
+ [0, 1, ty],
37
+ [0, 0, 1]])
38
+ return translation_matrix
39
+
40
+
41
+
42
+ def get_rotation_matrix(rotation, input_shape, centred=True):
43
+ theta = np.pi / 180 * np.array(rotation)
44
+ if centred:
45
+ rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1)
46
+ rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]])
47
+ else:
48
+ rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
49
+ [np.sin(theta), np.cos(theta), 0],
50
+ [0, 0, 1]])
51
+ return rotation_matrix
52
+
53
+ def get_zoom_matrix(zoom, input_shape, centred=True):
54
+ zx, zy = zoom
55
+ if centred:
56
+ zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0])
57
+ zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]])
58
+ else:
59
+ zoom_matrix = np.array([[zx, 0, 0],
60
+ [0, zy, 0],
61
+ [0, 0, 1]])
62
+ return zoom_matrix
63
+
64
+ def get_shear_matrix(shear_angle):
65
+ theta = (np.pi * shear_angle) / 180
66
+ shear_matrix = np.array([[1, -np.sin(theta), 0],
67
+ [0, np.cos(theta), 0],
68
+ [0, 0, 1]])
69
+ return shear_matrix
70
+
71
+ ###### AFFINE TRANSFORM ######
72
+ class RandomAffine(object):
73
+ """Apply random affine transformation on a numpy.ndarray (H x W x C)
74
+ Comment by co1818: this is still doing affine on 2d (H x W plane).
75
+ A same transform is applied to all C channels
76
+
77
+ Parameter:
78
+ ----------
79
+
80
+ alpha: Range [0, 4] seems good for small images
81
+
82
+ order: interpolation method (c.f. opencv)
83
+ """
84
+
85
+ def __init__(self,
86
+ rotation_range=None,
87
+ translation_range=None,
88
+ shear_range=None,
89
+ zoom_range=None,
90
+ zoom_keep_aspect=False,
91
+ interp='bilinear',
92
+ use_3d=False,
93
+ order=3):
94
+ """
95
+ Perform an affine transforms.
96
+
97
+ Arguments
98
+ ---------
99
+ rotation_range : one integer or float
100
+ image will be rotated randomly between (-degrees, degrees)
101
+
102
+ translation_range : (x_shift, y_shift)
103
+ shifts in pixels
104
+
105
+ *NOT TESTED* shear_range : float
106
+ image will be sheared randomly between (-degrees, degrees)
107
+
108
+ zoom_range : (zoom_min, zoom_max)
109
+ list/tuple with two floats between [0, infinity).
110
+ first float should be less than the second
111
+ lower and upper bounds on percent zoom.
112
+ Anything less than 1.0 will zoom in on the image,
113
+ anything greater than 1.0 will zoom out on the image.
114
+ e.g. (0.7, 1.0) will only zoom in,
115
+ (1.0, 1.4) will only zoom out,
116
+ (0.7, 1.4) will randomly zoom in or out
117
+ """
118
+
119
+ self.rotation_range = rotation_range
120
+ self.translation_range = translation_range
121
+ self.shear_range = shear_range
122
+ self.zoom_range = zoom_range
123
+ self.zoom_keep_aspect = zoom_keep_aspect
124
+ self.interp = interp
125
+ self.order = order
126
+ self.use_3d = use_3d
127
+
128
+ def build_M(self, input_shape):
129
+ tfx = []
130
+ final_tfx = np.eye(3)
131
+ if self.rotation_range:
132
+ rot = np.random.uniform(-self.rotation_range, self.rotation_range)
133
+ tfx.append(get_rotation_matrix(rot, input_shape))
134
+ if self.translation_range:
135
+ tx = np.random.uniform(-self.translation_range[0], self.translation_range[0])
136
+ ty = np.random.uniform(-self.translation_range[1], self.translation_range[1])
137
+ tfx.append(get_translation_matrix((tx,ty)))
138
+ if self.shear_range:
139
+ rot = np.random.uniform(-self.shear_range, self.shear_range)
140
+ tfx.append(get_shear_matrix(rot))
141
+ if self.zoom_range:
142
+ sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
143
+ if self.zoom_keep_aspect:
144
+ sy = sx
145
+ else:
146
+ sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
147
+
148
+ tfx.append(get_zoom_matrix((sx, sy), input_shape))
149
+
150
+ for tfx_mat in tfx:
151
+ final_tfx = np.dot(tfx_mat, final_tfx)
152
+
153
+ return final_tfx.astype(np.float32)
154
+
155
+ def __call__(self, image):
156
+ # build matrix
157
+ input_shape = image.shape[:2]
158
+ M = self.build_M(input_shape)
159
+
160
+ res = np.zeros_like(image)
161
+ #if isinstance(self.interp, Sequence):
162
+ if type(self.order) is list or type(self.order) is tuple:
163
+ for i, intp in enumerate(self.order):
164
+ if self.use_3d:
165
+ res[..., i] = affine_transform_3d_via_M(image[..., i], M[:2], interp=intp)
166
+ else:
167
+ res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp)
168
+ else:
169
+ # squeeze if needed
170
+ orig_shape = image.shape
171
+ image_s = np.squeeze(image)
172
+ if self.use_3d:
173
+ res = affine_transform_3d_via_M(image_s, M[:2], interp=self.order)
174
+ else:
175
+ res = affine_transform_via_M(image_s, M[:2], interp=self.order)
176
+ res = res.reshape(orig_shape)
177
+
178
+ #res = affine_transform_via_M(image, M[:2], interp=self.order)
179
+
180
+ return res
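+
+ # Usage sketch (assumption: `img` is H x W x 1 and `mask` is H x W; with order=[3, 0] the
+ # image channel uses a higher-order interpolation while the mask channel stays nearest-neighbour):
+ # aff = RandomAffine(rotation_range=5, translation_range=(5, 5),
+ #                    shear_range=5, zoom_range=(0.9, 1.2), order=[3, 0])
+ # warped = aff(np.concatenate([img, mask[..., None]], axis=-1))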
181
+
182
+ def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
183
+ imshape = image.shape
184
+ shape_size = imshape[:2]
185
+
186
+ # Random affine
187
+ warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1],
188
+ flags=interp, borderMode=borderMode)
189
+
190
+ #print(imshape, warped.shape)
191
+
192
+ warped = warped[..., np.newaxis].reshape(imshape)
193
+
194
+ return warped
195
+
196
+ def affine_transform_3d_via_M(vol, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
197
+ """
198
+ vol should be of shape (nx, ny, n1, ..., nm)
199
+ """
200
+ # go over slice slice
201
+ res = np.zeros_like(vol)
202
+ for i in range(vol.shape[2]):
203
+ res[:, :, i] = affine_transform_via_M(vol[:,:,i], M, borderMode=borderMode, interp=interp)
204
+
205
+ return res
206
+
207
+
208
+ ###### ELASTIC TRANSFORM ######
209
+ def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
210
+ """Elastic deformation of image as described in [Simard2003]_.
211
+ .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
212
+ Convolutional Neural Networks applied to Visual Document Analysis", in
213
+ Proc. of the International Conference on Document Analysis and
214
+ Recognition, 2003.
215
+ """
216
+ assert image.ndim == 3
217
+ shape = image.shape[:2]
218
+
219
+ dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
220
+ sigma, mode="constant", cval=0) * alpha
221
+ dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
222
+ sigma, mode="constant", cval=0) * alpha
223
+
224
+ x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
225
+ indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
226
+ result = np.empty_like(image)
227
+ for i in range(image.shape[2]):
228
+ result[:, :, i] = map_coordinates(
229
+ image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
230
+ return result
231
+
232
+ def elastic_transform_nd_3d(image, **kwargs):
233
+ """
234
+ image_w_mask should be of shape (nx, ny, nz, 3)
235
+ """
236
+ image_w_mask = image
237
+ start_time = time.time()
238
+ elastic_transform = A.ElasticTransform(alpha=10, sigma=20, alpha_affine=15, interpolation=1, border_mode=4, always_apply=True, p=0.5)
239
+ # print(f"elastic transform initilization took {time.time() - start_time} seconds")
240
+ img = image_w_mask[..., 0]
241
+ label = image_w_mask[..., -1]
242
+ transformed = elastic_transform(image=img, mask=label)
243
+ t_img = transformed['image'][..., np.newaxis]
244
+ t_mask = transformed['mask'][..., np.newaxis]
245
+ t_mask_bg = 1 - t_mask
246
+ t_mask = np.concatenate([t_mask_bg, t_mask], axis=-1)
247
+
248
+ comp = np.concatenate([t_img, t_mask], axis=-1)
249
+ return comp
250
+
251
+ def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False):
252
+ """Expects data to be (nx, ny, n1 ,..., nm)
253
+ params:
254
+ ------
255
+
256
+ alpha:
257
+ the scaling parameter.
258
+ E.g.: alpha=2 => distorts images up to 2x scaling
259
+
260
+ sigma:
261
+ standard deviation of gaussian filter.
262
+ E.g.
263
+ low (sig~=1e-3) => no smoothing, pixelated.
264
+ high (1/5 * imsize) => smooth, more like affine.
265
+ very high (1/2*im_size) => translation
266
+ """
267
+
268
+ if random_state is None:
269
+ random_state = np.random.RandomState(None)
270
+
271
+ shape = image.shape
272
+ imsize = shape[:2]
273
+ dim = shape[2:]
274
+
275
+ # Random affine
276
+ blur_size = int(4*sigma) | 1
277
+ dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
278
+ ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
279
+ dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
280
+ ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
281
+
282
+ # use as_strided to copy things over across n1...nn channels
283
+ dx = as_strided(dx.astype(np.float32),
284
+ strides=(0,) * len(dim) + (4*shape[1], 4),
285
+ shape=dim+(shape[0], shape[1]))
286
+ dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim))))
287
+
288
+ dy = as_strided(dy.astype(np.float32),
289
+ strides=(0,) * len(dim) + (4*shape[1], 4),
290
+ shape=dim+(shape[0], shape[1]))
291
+ dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim))))
292
+
293
+ coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim])
294
+ indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:],
295
+ [dy, dx] + [0] * len(dim))]
296
+
297
+ if lazy:
298
+ return indices
299
+ res = map_coordinates(image, indices, order=order, mode='reflect').reshape(shape)
300
+ return res
301
+
302
+ class ElasticTransform(object):
303
+ """Apply elastic transformation on a numpy.ndarray (H x W x C)
304
+ """
305
+
306
+ def __init__(self, alpha, sigma, order=1):
307
+ self.alpha = alpha
308
+ self.sigma = sigma
309
+ self.order = order
310
+
311
+ def __call__(self, image):
312
+ if isinstance(self.alpha, Sequence):
313
+ alpha = random_num_generator(self.alpha)
314
+ else:
315
+ alpha = self.alpha
316
+ if isinstance(self.sigma, Sequence):
317
+ sigma = random_num_generator(self.sigma)
318
+ else:
319
+ sigma = self.sigma
320
+ return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order)
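+
+ # Usage sketch (assumption: `comp` is an H x W x C numpy array, e.g. an image with its
+ # one-hot label stacked along the channel axis, as built in augutils.transform_with_label):
+ # elastic = ElasticTransform(alpha=10, sigma=5, order=1)
+ # comp_warped = elastic(comp)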
321
+
322
+ class RandomFlip3D(object):
323
+
324
+ def __init__(self, h=True, v=True, t=True, p=0.5):
325
+ """
326
+ Randomly flip an image horizontally and/or vertically with
327
+ some probability.
328
+
329
+ Arguments
330
+ ---------
331
+ h : boolean
332
+ whether to horizontally flip w/ probability p
333
+
334
+ v : boolean
335
+ whether to vertically flip w/ probability p
336
+
337
+ p : float between [0,1]
338
+ probability with which to apply allowed flipping operations
339
+ """
340
+ self.horizontal = h
341
+ self.vertical = v
342
+ self.depth = t
343
+ self.p = p
344
+
345
+ def __call__(self, x, y=None):
346
+ # horizontal flip with p = self.p
347
+ if self.horizontal:
348
+ if np.random.random() < self.p:
349
+ x = x[::-1, ...]
350
+
351
+ # vertical flip with p = self.p
352
+ if self.vertical:
353
+ if np.random.random() < self.p:
354
+ x = x[:, ::-1, ...]
355
+
356
+ if self.depth:
357
+ if np.random.random() < self.p:
358
+ x = x[..., ::-1]
359
+
360
+ return x
361
+
362
+
dataloaders/niftiio.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Utils for datasets
3
+ """
4
+ import numpy as np
7
+ import SimpleITK as sitk
8
+
9
+
10
+ def read_nii_bysitk(input_fid, peel_info = False):
11
+ """ read nii to numpy through simpleitk
12
+ peel_info: if True, also return direction, origin, spacing and array size
13
+ """
14
+ img_obj = sitk.ReadImage(input_fid)
15
+ img_np = sitk.GetArrayFromImage(img_obj)
16
+ if peel_info:
17
+ info_obj = {
18
+ "spacing": img_obj.GetSpacing(),
19
+ "origin": img_obj.GetOrigin(),
20
+ "direction": img_obj.GetDirection(),
21
+ "array_size": img_np.shape
22
+ }
23
+ return img_np, info_obj
24
+ else:
25
+ return img_np
26
+
27
+ def convert_to_sitk(input_mat, peeled_info):
28
+ """
29
+ write a numpy array to sitk image object with essential meta-data
30
+ """
31
+ nii_obj = sitk.GetImageFromArray(input_mat)
32
+ if peeled_info:
33
+ nii_obj.SetSpacing( peeled_info["spacing"] )
34
+ nii_obj.SetOrigin( peeled_info["origin"] )
35
+ nii_obj.SetDirection(peeled_info["direction"] )
36
+ return nii_obj
37
+
38
+ def np2itk(img, ref_obj):
39
+ """
40
+ img: numpy array
41
+ ref_obj: reference sitk object for copying information from
42
+ """
43
+ itk_obj = sitk.GetImageFromArray(img)
44
+ itk_obj.SetSpacing( ref_obj.GetSpacing() )
45
+ itk_obj.SetOrigin( ref_obj.GetOrigin() )
46
+ itk_obj.SetDirection( ref_obj.GetDirection() )
47
+ return itk_obj
48
+
models/ProtoMedSAM.py ADDED
@@ -0,0 +1,267 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from models.ProtoSAM import ModelWrapper
7
+ from segment_anything import sam_model_registry
8
+ from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components
9
+
10
+ class ProtoMedSAM(nn.Module):
11
+ def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/medsam_vit_b.pth", debug=False, use_cca=False, coarse_pred_only=False):
12
+ super().__init__()
13
+ if isinstance(image_size, int):
14
+ image_size = (image_size, image_size)
15
+ self.image_size = image_size
16
+ self.coarse_segmentation_model = coarse_segmentation_model
17
+ self.get_sam(sam_pretrained_path)
18
+ self.coarse_pred_only = coarse_pred_only
19
+ self.debug = debug
20
+ self.use_cca = use_cca
21
+
22
+
23
+ def get_sam(self, checkpoint_path):
24
+ model_type="vit_b" # TODO make generic?
25
+ if 'vit_h' in checkpoint_path:
26
+ model_type = "vit_h"
27
+ self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
28
+
29
+
30
+ @torch.no_grad()
31
+ def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
32
+ box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
33
+ if len(box_torch.shape) == 2:
34
+ box_torch = box_torch[:, None, :] # (B, 1, 4)
35
+
36
+ sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
37
+ points=None,
38
+ boxes=box_torch,
39
+ masks=None,
40
+ )
41
+ low_res_logits, conf = self.medsam.mask_decoder(
42
+ image_embeddings=img_embed, # (B, 256, 64, 64)
43
+ image_pe=self.medsam.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
44
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
45
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
46
+ multimask_output=True if query_label is not None else False,
47
+ )
48
+
49
+ low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
50
+
51
+ low_res_pred = F.interpolate(
52
+ low_res_pred,
53
+ size=(H, W),
54
+ mode="bilinear",
55
+ align_corners=False,
56
+ ) # (1, 1, gt.shape)
57
+ low_res_pred = low_res_pred.squeeze().cpu() # (256, 256)
58
+
59
+ low_res_pred = low_res_pred.numpy()
60
+ medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
61
+
62
+ if query_label is not None:
63
+ medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]
64
+
65
+ return medsam_seg, conf.cpu().detach().numpy()
66
+
67
+ def get_iou(self, pred, label):
68
+ """
69
+ pred np array shape h,w type uint8
70
+ label np array shape h,w type uint8
71
+ """
72
+ tp = np.logical_and(pred, label).sum()
73
+ fp = np.logical_and(pred, 1-label).sum()
74
+ fn = np.logical_and(1-pred, label).sum()
75
+ iou = tp / (tp + fp + fn)
76
+ return iou
77
+
78
+ def get_best_mask(self, masks, labels):
79
+ """
80
+ masks np shape ( B, h, w)
81
+ labels torch shape (1, H, W)
82
+ """
83
+ np_labels = labels[0].clone().detach().cpu().numpy()
84
+ best_iou, best_mask = 0, None
85
+ for mask in masks:
86
+ iou = self.get_iou(mask, np_labels)
87
+ if iou > best_iou:
88
+ best_iou = iou
89
+ best_mask = mask
90
+
91
+ return best_mask
92
+
93
+ def get_bbox(self, pred):
94
+ """
95
+ pred is tensor of shape (H,W) - 1 is fg, 0 is bg.
96
+ returns the bbox of pred as np.array([xmin, ymin, xmax, ymax])
97
+ """
98
+ if isinstance(pred, np.ndarray):
99
+ pred = torch.from_numpy(pred)
100
+ if pred.max() == 0:
101
+ return None
102
+ indices = torch.nonzero(pred)
103
+ ymin, xmin = indices.min(dim=0)[0]
104
+ ymax, xmax = indices.max(dim=0)[0]
105
+ return np.array([xmin, ymin, xmax, ymax])
106
+
107
+
108
+ def get_bbox_per_cc(self, conn_components):
109
+ """
110
+ conn_components: output of cca function
111
+ return list of bboxes per connected component, each bbox is a list of 2d points
112
+ """
113
+ bboxes = []
114
+ for i in range(1, conn_components[0]):
115
+ # get the indices of the foreground points
116
+ pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
117
+ bboxes.append(self.get_bbox(pred))
118
+
119
+ bboxes = np.array(bboxes)
120
+ return bboxes
121
+
122
+ def forward(self, query_image, coarse_model_input, degrees_rotate=0):
123
+ """
124
+ query_image: 3d tensor of shape (1, 3, H, W)
125
+ images should be normalized with mean and std but not to [0, 1]?
126
+ """
127
+ original_size = query_image.shape[-2]
128
+ # rotate query_image by degrees_rotate
129
+ rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
130
+ # print(f"rotating query image took {time.time() - start_time} seconds")
131
+ coarse_model_input.set_query_images(rotated_img)
132
+ output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
133
+ # print(f"ALPNet took {time.time() - start_time} seconds")
134
+
135
+ if degrees_rotate != 0:
136
+ output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
137
+ # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
138
+ else:
139
+ output_logits = output_logits_rot
140
+
141
+ # check if softmax is needed
142
+ # output_p = output_logits.softmax(dim=1)
143
+ output_p = output_logits
144
+ pred = output_logits.argmax(dim=1)[0]
145
+ if self.debug:
146
+ _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
147
+ plt.subplot(132)
148
+ plt.imshow(query_image[0,0].detach().cpu())
149
+ plt.imshow(_pred, alpha=0.5)
150
+ plt.subplot(131)
151
+ # plot heatmap of prob of being fg
152
+ plt.imshow(output_p[0, 1].detach().cpu())
153
+ # plot rotated query image and rotated pred
154
+ output_p_rot = output_logits_rot.softmax(dim=1)
155
+ _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
156
+ _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
157
+ plt.subplot(133)
158
+ plt.imshow(rotated_img[0, 0].detach().cpu())
159
+ plt.imshow(_pred_rot, alpha=0.5)
160
+ plt.savefig('debug/coarse_pred.png')
161
+ plt.close()
162
+
163
+ if self.coarse_pred_only:
164
+ output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
165
+ pred = output_logits.argmax(dim=1)[0]
166
+ conf = get_confidence_from_logits(output_logits)
167
+ if self.use_cca:
168
+ _pred = np.array(pred.detach().cpu())
169
+ _pred, conf = cca(_pred, output_logits, return_conf=True)
170
+ pred = torch.from_numpy(_pred)
171
+ if self.training:
172
+ return output_logits, [conf]
173
+ return pred, [conf]
174
+
175
+ if query_image.shape[-2:] != self.image_size:
176
+ query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
177
+ output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
178
+ if need_softmax(output_logits):
179
+ output_logits = output_logits.softmax(dim=1)
180
+
181
+ output_p = output_logits
182
+ pred = output_p.argmax(dim=1)[0]
183
+
184
+ _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
185
+ if self.use_cca:
186
+ conn_components = cca(_pred, output_logits, return_cc=True)
187
+ conf=None
188
+ else:
189
+ conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
190
+ if self.debug:
191
+ plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
192
+ # print(f"connected components took {time.time() - start_time} seconds")
193
+
194
+ if _pred.max() == 0:
195
+ if output_p.shape[-2:] != original_size:
196
+ output_p = F.interpolate(output_p, size=original_size, mode='bilinear')
197
+ return output_p.argmax(dim=1)[0], [0]
198
+
199
+ H, W = query_image.shape[-2:]
200
+ # bbox = self.get_bbox(_pred)
201
+ bbox = self.get_bbox_per_cc(conn_components)
202
+ bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
203
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
204
+ with torch.no_grad():
205
+ image_embedding = self.medsam.image_encoder(query_image)
206
+
207
+ medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W)
208
+
209
+ if self.debug:
210
+ fig, ax = plt.subplots(1, 2)
211
+ ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
212
+ show_mask(medsam_seg, ax[0])
213
+ ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
214
+ show_box(bbox[0], ax[1])
215
+ plt.savefig('debug/medsam_pred.png')
216
+ plt.close()
217
+
218
+ medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
219
+ if medsam_seg.shape[-2:] != original_size:
220
+ medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]
221
+
222
+ return medsam_seg, [conf]
223
+
224
+ def segment_all(self, query_image, query_label):
225
+ H, W = query_image.shape[-2:]
226
+ # bbox = self.get_bbox(_pred)
227
+ # bbox = self.get_bbox_per_cc(conn_components)
228
+ # bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
229
+ bbox = np.array([[0, 0, W, H]])
230
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
231
+ with torch.no_grad():
232
+ image_embedding = self.medsam.image_encoder(query_image)
233
+
234
+ medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W, query_label)
235
+
236
+ if self.debug:
237
+ fig, ax = plt.subplots(1, 2)
238
+ ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
239
+ show_mask(medsam_seg, ax[0])
240
+ ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
241
+ show_box(bbox[0], ax[1])
242
+ plt.savefig('debug/medsam_pred.png')
243
+ plt.close()
244
+
245
+ medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
246
+ if medsam_seg.shape[-2:] != (H, W):
247
+ medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')[0][0]
248
+
249
+ return medsam_seg.view(H,W), [conf]
250
+
251
+
252
+ def show_mask(mask, ax, random_color=False):
253
+ if random_color:
254
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
255
+ else:
256
+ color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
257
+ h, w = mask.shape[-2:]
258
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
259
+ ax.imshow(mask_image)
260
+
261
+
262
+ def show_box(box, ax):
263
+ x0, y0 = box[0], box[1]
264
+ w, h = box[2] - box[0], box[3] - box[1]
265
+ ax.add_patch(
266
+ plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
267
+ )
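The refinement path above converts the coarse prediction into one box prompt per connected component and rescales each box to the model input size before querying MedSAM. A minimal, self-contained sketch of that box-extraction step (illustration only, not repo code; it assumes OpenCV's `connectedComponentsWithStats` and NumPy):

```python
# Sketch: one XYXY box per connected component of a binary mask,
# mirroring get_bbox_per_cc above (end coordinates here are exclusive).
import numpy as np
import cv2

def boxes_per_component(mask: np.ndarray) -> np.ndarray:
    """mask: (H, W) binary array -> (N, 4) boxes in XYXY format."""
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8))
    boxes = []
    for cc_id in range(1, num_labels):            # label 0 is the background
        x, y, w, h, _area = stats[cc_id]
        boxes.append([x, y, x + w, y + h])
    return np.array(boxes)

toy = np.zeros((64, 64), dtype=np.uint8)
toy[5:15, 5:15] = 1
toy[40:60, 30:50] = 1
boxes = boxes_per_component(toy)
print(boxes)                                      # [[ 5  5 15 15] [30 40 50 60]]
# The forward pass above then rescales: bbox / np.array([W, H, W, H]) * max(image_size)
H, W = toy.shape
print(boxes / np.array([W, H, W, H]) * 1024)      # boxes in an (assumed) 1024x1024 model frame
```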
models/ProtoSAM.py ADDED
@@ -0,0 +1,708 @@
1
+ import warnings
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from models.grid_proto_fewshot import FewShotSeg
8
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
9
+ from models.SamWrapper import SamWrapper
10
+ from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits
11
+ from util.lora import inject_trainable_lora
12
+ from models.segment_anything.utils.transforms import ResizeLongestSide
13
+ import cv2
14
+ import time
15
+ from abc import ABC, abstractmethod
16
+
17
+ CONF_MODE="conf"
18
+ CENTROID_MODE="centroid"
19
+ BOTH_MODE="both"
20
+ POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE)
21
+
22
+ TYPE_ALPNET="alpnet"
23
+ TYPE_SAM="sam"
24
+
25
+ def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"):
26
+ num_labels, labels, stats, centroids = cca_output
27
+ # Create an output image with random colors for each component
28
+ output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8)
29
+ for label in range(1, num_labels): # Start from 1 to skip the background
30
+ mask = labels == label
31
+ output_image[mask] = np.random.randint(0, 255, size=3)
32
+
33
+ # Plotting the original and the colored components image
34
+ plt.figure(figsize=(10, 5))
35
+ plt.subplot(121), plt.imshow(original_image), plt.title('Original Image')
36
+ plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components')
37
+ if confidences is not None:
38
+ # Overlay the component centroids on the connected-components plot, coloured by their confidence values
39
+ plt.subplot(122)
40
+ scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet')
41
+ plt.colorbar(scatter)
42
+
43
+ plt.savefig(title)
44
+ plt.close()
45
+
46
+ class SegmentationInput(ABC):
47
+ @abstractmethod
48
+ def set_query_images(self, query_images):
49
+ pass
50
+
51
+ def to(self, device):
52
+ pass
53
+
54
+ class SegmentationOutput(ABC):
55
+ @abstractmethod
56
+ def get_prediction(self):
57
+ pass
58
+
59
+ class ALPNetInput(SegmentationInput): # for alpnet
60
+ def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None):
61
+ self.supp_imgs = [support_images]
62
+ self.fore_mask = [support_labels]
63
+ self.back_mask = [[1 - sup_labels for sup_labels in support_labels]]
64
+ self.qry_imgs = [query_images]
65
+ self.isval = isval
66
+ self.val_wsize = val_wsize
67
+ self.show_viz = show_viz
68
+ self.supp_fts = supp_fts
69
+
70
+ def set_query_images(self, query_images):
71
+ self.qry_imgs = [query_images]
72
+
73
+ def to(self, device):
74
+ self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]]
75
+ self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]]
76
+ self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]]
77
+ self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs]
78
+ if self.supp_fts is not None:
79
+ self.supp_fts = self.supp_fts.to(device)
80
+
81
+ class ALPNetOutput(SegmentationOutput):
82
+ def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts):
83
+ self.pred = pred
84
+ self.align_loss = align_loss
85
+ self.sim_maps = sim_maps
86
+ self.assign_maps = assign_maps
87
+ self.proto_grid = proto_grid
88
+ self.supp_fts = supp_fts
89
+ self.qry_fts = qry_fts
90
+
91
+ def get_prediction(self):
92
+ return self.pred
93
+
94
+ class SAMWrapperInput(SegmentationInput):
95
+ def __init__(self, image, image_labels):
96
+ self.image = image
97
+ self.image_labels = image_labels
98
+
99
+ def set_query_images(self, query_images):
100
+ B, C, H, W = query_images.shape
101
+ if isinstance(query_images, torch.Tensor):
102
+ query_images = query_images.cpu().detach().numpy()
103
+ assert B == 1, "batch size must be 1"
104
+ query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255
105
+ query_images = query_images.astype(np.uint8)
106
+ self.image = np.transpose(query_images[0], (1, 2, 0))
107
+
108
+ def to(self, device):
109
+ pass
110
+
111
+
112
+ class InputFactory(ABC):
113
+ @staticmethod
114
+ def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None):
115
+
116
+ if input_type == TYPE_ALPNET:
117
+ return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts)
118
+ elif input_type == TYPE_SAM:
119
+ qimg = np.array(query_image.detach().cpu())
120
+ B,C,H,W = qimg.shape
121
+ assert B == 1, "batch size must be 1"
122
+ gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W)
123
+ assert np.unique(gts).shape[0] <= 2, "support labels must be binary"
124
+ gts[gts > 0] = 1
125
+ qimg = qimg.reshape(H,W,C)
126
+ qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255
127
+ qimg = qimg.astype(np.uint8)
128
+ return SAMWrapperInput(qimg, gts)
129
+ else:
130
+ raise ValueError(f"input_type {input_type} not supported")
131
+
132
+
133
+ class ModelWrapper(ABC):
134
+ def __init__(self, model):
135
+ self.model = model
136
+
137
+ def __call__(self, input_data: SegmentationInput)->SegmentationOutput:
138
+ pass
139
+
140
+ def state_dict(self):
141
+ return self.model.state_dict()
142
+
143
+ def load_state_dict(self, state_dict):
144
+ self.model.load_state_dict(state_dict)
145
+
146
+ def eval(self):
147
+ self.model.eval()
148
+
149
+ def train(self):
150
+ self.model.train()
151
+
152
+ def parameters(self):
153
+ pass
154
+
155
+ class ALPNetWrapper(ModelWrapper):
156
+ def __init__(self, model: FewShotSeg):
157
+ super().__init__(model)
158
+
159
+ def __call__(self, input_data: ALPNetInput):
160
+ output = self.model(**input_data.__dict__)
161
+ output = ALPNetOutput(*output)
162
+ return output.pred
163
+
164
+ def parameters(self):
165
+ return self.model.encoder.parameters()
166
+
167
+ def train(self):
168
+ self.model.encoder.train()
169
+
170
+ class SamWrapperWrapper(ModelWrapper):
171
+ def __init__(self, model:SamWrapper):
172
+ super().__init__(model)
173
+
174
+ def __call__(self, input_data: SAMWrapperInput):
175
+ pred = self.model(**input_data.__dict__)
176
+ # make pred look like logits
177
+ pred = torch.tensor(pred).float()[None, None, ...]
178
+ pred = torch.cat([1-pred, pred], dim=1)
179
+ return pred
180
+
181
+ def to(self, device):
182
+ self.model.sam.to(device)
183
+
184
+ class ProtoSAM(nn.Module):
185
+ def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ):
186
+ super().__init__()
187
+ if isinstance(image_size, int):
188
+ image_size = (image_size, image_size)
189
+ self.image_size = image_size
190
+ self.coarse_segmentation_model = coarse_segmentation_model
191
+ self.get_sam(sam_pretrained_path, use_sam_trans)
192
+ self.num_points_for_sam = num_points_for_sam
193
+ self.use_points = use_points
194
+ self.use_bbox = use_bbox # if False then uses points
195
+ self.use_mask = use_mask
196
+ self.use_neg_points = use_neg_points
197
+ assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask"
198
+ self.use_cca = use_cca
199
+ self.point_mode = point_mode
200
+ if self.point_mode not in POINT_MODES:
201
+ raise ValueError(f"point mode must be one of {POINT_MODES}")
202
+ self.debug=debug
203
+ self.coarse_pred_only = coarse_pred_only
204
+
205
+ def get_sam(self, checkpoint_path, use_sam_trans):
206
+ model_type="vit_b" # TODO make generic?
207
+ if 'vit_h' in checkpoint_path:
208
+ model_type = "vit_h"
209
+ self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
210
+ self.predictor = SamPredictor(self.sam)
211
+ self.sam.requires_grad_(False)
212
+ if use_sam_trans:
213
+ # sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1])
214
+ sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size)
215
+ sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1)
216
+ sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1)
217
+ else:
218
+ sam_trans = None
219
+
220
+ self.sam_trans = sam_trans
221
+
222
+ def get_bbox(self, pred):
223
+ '''
224
+ pred tensor of shape (H, W) where 1 represents foreground and 0 represents background
225
+ returns a list of 2d points representing the bbox
226
+ '''
227
+ if isinstance(pred, np.ndarray):
228
+ pred = torch.tensor(pred)
229
+ # get the indices of the foreground points
230
+ indices = torch.nonzero(pred)
231
+ # get the min and max of the indices
232
+ min_x = indices[:, 1].min()
233
+ max_x = indices[:, 1].max()
234
+ min_y = indices[:, 0].min()
235
+ max_y = indices[:, 0].max()
236
+ # get the bbox
237
+ bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
238
+
239
+
240
+ return bbox
241
+
242
+ def get_bbox_per_cc(self, conn_components):
243
+ """
244
+ conn_components: output of cca function
245
+ return list of bboxes per connected component, each bbox is a list of 2d points
246
+ """
247
+ bboxes = []
248
+ for i in range(1, conn_components[0]):
249
+ # get the indices of the foreground points
250
+ indices = torch.nonzero(torch.tensor(conn_components[1] == i))
251
+ # get the min and max of the indices
252
+ min_x = indices[:, 1].min()
253
+ max_x = indices[:, 1].max()
254
+ min_y = indices[:, 0].min()
255
+ max_y = indices[:, 0].max()
256
+ # get the bbox
257
+ # bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
258
+ # bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]]
259
+ # bbox should be in a XYXY format
260
+ bbox = [min_x, min_y, max_x, max_y]
261
+ bboxes.append(bbox)
262
+
263
+ bboxes = np.array(bboxes)
264
+ return bboxes
265
+
266
+ def get_most_conf_points(self, output_p_fg, pred, k):
267
+ '''
268
+ get the k most confident points from pred
269
+ output_p: 3d tensor of shape (H, W)
270
+ pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
271
+ '''
272
+ # Create a mask where pred is 1
273
+ mask = pred.bool()
274
+
275
+ # Apply the mask to output_p_fg
276
+ masked_output_p_fg = output_p_fg[mask]
277
+ if masked_output_p_fg.numel() == 0:
278
+ return None, None
279
+ # Get the top k probabilities and their indices
280
+ confidences, indices = torch.topk(masked_output_p_fg, k)
281
+
282
+ # Get the locations of the top k points in xy format
283
+ locations = torch.nonzero(mask)[indices]
284
+ # convert locations to xy format
285
+ locations = locations[:, [1, 0]]
286
+ # convert locations to list of lists
287
+ # points = [loc.tolist() for loc in locations]
288
+
289
+ return locations.numpy(), [float(conf.item()) for conf in confidences]
290
+
291
+
292
+ def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None):
293
+ '''
294
+ points: np array of shape (N, 2) where each row is a point in xy format
295
+ pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
296
+ image: 2d tensor of shape (H,W) representing the image
297
+ bbox: list or np array of shape (N, 4) where each row is a bbox in xyxy format
298
+ '''
299
+ warnings.filterwarnings('ignore', category=UserWarning)
300
+ if isinstance(pred, torch.Tensor):
301
+ pred = pred.cpu().detach().numpy()
302
+ if len(image.shape) == 3 and image.shape[0] == 3:
303
+ image = image.permute(1, 2, 0)
304
+ if title is None:
305
+ title="debug/most_conf_points.png"
306
+
307
+ fig = plt.figure()
308
+ image = (image - image.min()) / (image.max() - image.min())
309
+ plt.imshow(image)
310
+ plt.imshow(pred, alpha=0.5)
311
+ for i, point in enumerate(points):
312
+ plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red')
313
+ if confidences is not None:
314
+ plt.text(point[0], point[1], f"{confidences[i]:.3f}", fontsize=12, color='red')
315
+ # assume points is a list of lists
316
+ if bboxes is not None:
317
+ for bbox in bboxes:
318
+ if bbox is None:
319
+ continue
320
+ bbox = np.array(bbox)
321
+ # plt.scatter(bbox[:, 1], bbox[:, 0], c='red')
322
+ # plot a line connecting the points
323
+ box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]])
324
+ box = np.vstack([box, box[0]])
325
+ plt.plot(box[:, 0], box[:, 1], c='green')
326
+ plt.colorbar()
327
+ fig.savefig(title)
328
+ plt.close(fig)
329
+
330
+ def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None):
331
+ if len(image.shape) == 3:
332
+ image = image.permute(1, 2, 0)
333
+ image = (image - image.min()) / (image.max() - image.min())
334
+ for i, (mask, score) in enumerate(zip(masks, scores)):
335
+ plt.figure(figsize=(10,10))
336
+ plt.imshow(image)
337
+ show_mask(mask, plt.gca())
338
+ if input_point is not None:
339
+ show_points(input_point, input_label, plt.gca())
340
+ if input_box is not None:
341
+ show_box(input_box, plt.gca())
342
+ plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
343
+ # plt.axis('off')
344
+ plt.savefig(f'debug/sam_mask_{i+1}.png')
345
+ plt.close()
346
+ if i > 5:
347
+ break
348
+
349
+ def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1):
350
+ """
351
+ args:
352
+ conn_components: output of cca function
353
+ output_p: 3d tensor of shape (1, 2, H, W)
354
+ get_neg_points: bool, if True then return the negative points
355
+ l: int, number of negative points to get
356
+ """
357
+ sam_input_points = []
358
+ sam_neg_points = []
359
+ fg_p = output_p[0, 1].detach().cpu()
360
+
361
+ if get_neg_points:
362
+ # get global negative points
363
+ bg_p = output_p[0, 0].detach().cpu()
364
+ bg_p[bg_p < 0.95] = 0
365
+ bg_pred = torch.where(bg_p > 0, 1, 0)
366
+ glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1)
367
+ if self.debug:
368
+ # plot the bg_p as a heatmap
369
+ plt.figure()
370
+ plt.imshow(bg_p)
371
+ plt.colorbar()
372
+ plt.savefig('debug/bg_p_heatmap.png')
373
+ plt.close()
374
+
375
+ for i, cc_id in enumerate(np.unique(conn_components[1])):
376
+ # get self.num_points_for_sam most confident points from pred
377
+ if cc_id == 0:
378
+ continue # skip background
379
+ pred = torch.tensor(conn_components[1] == cc_id).float()
380
+
381
+ if self.point_mode == CONF_MODE:
382
+ points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2)
383
+ elif self.point_mode == CENTROID_MODE:
384
+ points = conn_components[3][cc_id][None, :] # (1, 2)
385
+ confidences = [1 for _ in range(len(points))]
386
+ elif self.point_mode == BOTH_MODE:
387
+ points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam)
388
+ point = conn_components[3][cc_id][None, :]
389
+ points = np.vstack([points, point]) # (N+1, 2)
390
+ confidences.append(1)
391
+ else:
392
+ raise NotImplementedError(f"point mode {self.point_mode} not implemented")
393
+ sam_input_points.append(np.array(points))
394
+
395
+ if get_neg_points:
396
+ pred_uint8 = (pred.numpy() * 255).astype(np.uint8)
397
+
398
+ # Dilate the mask to expand it
399
+ kernel_size = 3 # Size of the dilation kernel, adjust accordingly
400
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
401
+ dilation_iterations = 10 # Number of times dilation is applied, adjust as needed
402
+ dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations)
403
+
404
+ # Subtract the original mask from the dilated mask
405
+ # This will give a boundary that is only outside the original mask
406
+ outside_boundary = dilated_mask - pred_uint8
407
+
408
+ # Convert back to torch tensor and normalize
409
+ boundary = torch.tensor(outside_boundary).float() / 255
410
+ try:
411
+ bg_p = output_p[0, 0].detach().cpu()
412
+ neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l)
413
+ except RuntimeError as e:
414
+ # no valid boundary negative points for this component; fall back to the global negative points below
415
+ neg_points = None
416
+ # append global negative points to the negative points
417
+ if neg_points is not None and glob_neg_points is not None:
418
+ neg_points = np.vstack([neg_points, glob_neg_points])
419
+ else:
420
+ neg_points = glob_neg_points if neg_points is None else neg_points
421
+ if self.debug and neg_points is not None:
422
+ # draw an image with 2 subplots, one is the pred and the other is the boundary
423
+ plt.figure()
424
+ plt.subplot(121)
425
+ plt.imshow(pred)
426
+ plt.imshow(boundary, alpha=0.5)
427
+ # plot the neg points
428
+ plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
429
+ plt.subplot(122)
430
+ plt.imshow(pred)
431
+ plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
432
+ plt.savefig('debug/pred_and_boundary.png')
433
+ plt.close()
434
+ sam_neg_points.append(neg_points)
435
+ else:
436
+ # create a list of None same shape as points
437
+ sam_neg_points = [None for _ in range(len(sam_input_points))]
438
+
439
+ sam_input_labels = np.array([l+1 for l, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))])
440
+ sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2)
441
+ # if get_neg_points:
442
+ sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None
443
+ if sam_neg_input_points is not None:
444
+ sam_neg_input_points = sam_neg_points
445
+ sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) )
446
+ else:
447
+ sam_neg_input_points = None
448
+ sam_neg_input_labels = None
449
+
450
+ return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels
451
+
452
+ def get_sam_input_mask(self, conn_components):
453
+ sam_input_masks = []
454
+ sam_input_mask_labels = []
455
+ for i, cc_id in enumerate(np.unique(conn_components[1])):
456
+ # get self.num_points_for_sam most confident points from pred
457
+ if cc_id == 0:
458
+ continue
459
+ pred = torch.tensor(conn_components[1] == cc_id).float()
460
+ sam_input_masks.append(pred)
461
+ sam_input_mask_labels.append(cc_id)
462
+
463
+ sam_input_masks = np.stack(sam_input_masks)
464
+ sam_input_mask_labels = np.array(sam_input_mask_labels)
465
+
466
+ return sam_input_masks, sam_input_mask_labels
467
+
468
+ def predict_w_masks(self, sam_input_masks, qry_img, original_size):
469
+ masks = []
470
+ scores = []
471
+ for in_mask in sam_input_masks:
472
+ in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST)
473
+ in_mask[in_mask == 1] = 10
474
+ in_mask[in_mask == 0] = -8
475
+ assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
476
+ self.predictor.set_image(qry_img)
477
+ mask, score, _ = self.predictor.predict(
478
+ mask_input=in_mask[None, ...].astype(np.uint8),
479
+ multimask_output=True)
480
+ # get max index from score
481
+ if self.debug:
482
+ # plot each channel of mask
483
+ fig, ax = plt.subplots(1, 4, figsize=(15, 5))
484
+ for i in range(mask.shape[0]):
485
+ ax[i].imshow(qry_img)
486
+ ax[i].imshow(mask[i], alpha=0.5)
487
+ ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18)
488
+ # ax[i].axis('off')
489
+ ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST))
490
+ fig.savefig(f'debug/sam_mask_from_mask_prompts.png')
491
+ plt.close(fig)
492
+
493
+
494
+ max_index = score.argmax()
495
+ masks.append(mask[max_index])
496
+ scores.append(score[max_index])
497
+
498
+ return masks, scores
499
+
500
+ def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False):
501
+ masks, scores = [], []
502
+ self.predictor.set_image(qry_img)
503
+ # if sam_input_points is None:
504
+ # sam_input_points = [None for _ in range(len(bboxes))]
505
+ for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points):
506
+ assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
507
+ points = point
508
+ point_labels = np.array([1] * len(point)) if point is not None else None
509
+ if self.use_neg_points:
510
+ neg_points = [npoint for npoint in neg_point if None not in npoint]
511
+ points = np.vstack([point, *neg_points])
512
+ point_labels = np.array([1] * len(point) + [0] * len(neg_points))
513
+ if self.debug:
514
+ self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points
515
+ mask, score, _ = self.predictor.predict(
516
+ point_coords=points,
517
+ point_labels=point_labels,
518
+ # box=bbox_xyxy[None, :] if bbox_xyxy is not None else None,
519
+ box = bbox_xyxy if bbox_xyxy is not None else None,
520
+ # mask_input=sam_mask_input,
521
+ return_logits=return_logits,
522
+ multimask_output=False if self.use_cca else True
523
+ )
524
+ # best_pred_idx = np.argmax(score)
525
+ best_pred_idx = 0
526
+ masks.append(mask[best_pred_idx])
527
+ scores.append(score[best_pred_idx])
528
+
529
+ if self.debug:
530
+ # pass
531
+ self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None)
532
+
533
+ return masks, scores
534
+
535
+
536
+ def forward(self, query_image, coarse_model_input, degrees_rotate=0):
537
+ """
538
+ query_image: 3d tensor of shape (1, 3, H, W)
539
+ images should be normalized with mean and std but not to [0, 1]?
540
+ """
541
+ original_size = query_image.shape[-2]
542
+ # rotate query_image by degrees_rotate
543
+ start_time = time.time()
544
+ rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
545
+ # print(f"rotating query image took {time.time() - start_time} seconds")
546
+ start_time = time.time()
547
+ coarse_model_input.set_query_images(rotated_img)
548
+ output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
549
+ # print(f"ALPNet took {time.time() - start_time} seconds")
550
+
551
+ if degrees_rotate != 0:
552
+ start_time = time.time()
553
+ output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
554
+ # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
555
+ else:
556
+ output_logits = output_logits_rot
557
+
558
+ # check if softmax is needed
559
+ output_p = output_logits.softmax(dim=1)
560
+ # output_p = output_logits
561
+ pred = output_logits.argmax(dim=1)[0]
562
+ if self.debug:
563
+ _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
564
+ plt.subplot(132)
565
+ plt.imshow(query_image[0,0].detach().cpu())
566
+ plt.imshow(_pred, alpha=0.5)
567
+ plt.subplot(131)
568
+ # plot heatmap of prob of being fg
569
+ plt.imshow(output_p[0, 1].detach().cpu())
570
+ # plot rotated query image and rotated pred
571
+ output_p_rot = output_logits_rot.softmax(dim=1)
572
+ _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
573
+ _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
574
+ plt.subplot(133)
575
+ plt.imshow(rotated_img[0, 0].detach().cpu())
576
+ plt.imshow(_pred_rot, alpha=0.5)
577
+ plt.savefig('debug/coarse_pred.png')
578
+ plt.close()
579
+
580
+ if self.coarse_pred_only:
581
+ output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
582
+ pred = output_logits.argmax(dim=1)[0]
583
+ conf = get_confidence_from_logits(output_logits)
584
+ if self.use_cca:
585
+ _pred = np.array(pred.detach().cpu())
586
+ _pred, conf = cca(_pred, output_logits, return_conf=True)
587
+ pred = torch.from_numpy(_pred)
588
+ if self.training:
589
+ return output_logits, [conf]
590
+ # Ensure pred is a float tensor for consistent visualization
591
+ return pred.float(), [conf]
592
+
593
+ if query_image.shape[-2:] != self.image_size:
594
+ query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
595
+ output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
596
+ # if need_softmax(output_logits):
597
+ # output_logits = output_logits.softmax(dim=1)
598
+
599
+ # output_p = output_logits
600
+ output_p = output_logits.softmax(dim=1)
601
+ pred = output_p.argmax(dim=1)[0]
602
+
603
+ _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
604
+ start_time = time.time()
605
+ if self.use_cca:
606
+ conn_components = cca(_pred, output_logits, return_cc=True)
607
+ conf=None
608
+ else:
609
+ conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
610
+ if self.debug:
611
+ plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
612
+ # print(f"connected components took {time.time() - start_time} seconds")
613
+ if _pred.max() == 0:
614
+ return output_p.argmax(dim=1)[0], [0]
615
+
616
+ # get bbox from pred
617
+ if self.use_bbox:
618
+ start_time = time.time()
619
+ try:
620
+ bboxes = self.get_bbox_per_cc(conn_components)
621
+ except Exception:
622
+ bboxes = [None] * conn_components[0]
623
+ else:
624
+ bboxes = [None] * conn_components[0]
625
+ # print(f"getting bboxes took {time.time() - start_time} seconds")
626
+
627
+
628
+ start_time = time.time()
629
+ if self.use_points:
630
+ sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1)
631
+ else:
632
+ sam_input_points = [None] * conn_components[0]
633
+ sam_input_point_labels = [None] * conn_components[0]
634
+ sam_neg_input_points = [None] * conn_components[0]
635
+ sam_neg_input_labels = [None] * conn_components[0]
636
+ # print(f"getting sam input points took {time.time() - start_time} seconds")
637
+
638
+ if self.use_mask:
639
+ sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components)
640
+ else:
641
+ sam_input_masks = None
642
+ sam_input_mask_labels = None
643
+
644
+ if self.debug and sam_input_points is not None:
645
+ title = f'debug/most_conf_points.png'
646
+ if self.use_cca:
647
+ title = f'debug/most_conf_points_cca.png'
648
+ # convert points to a list where each item is a list of 2 elements in xy format
649
+ self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points
650
+
651
+ # self.sam_trans = None
652
+ if self.sam_trans is None:
653
+ query_image = query_image.permute(1, 2, 0).detach().cpu().numpy()
654
+ else:
655
+ query_image = self.sam_trans.apply_image_torch(query_image[0])
656
+ query_image = self.sam_trans.preprocess(query_image)
657
+ query_image = query_image.permute(1, 2, 0).detach().cpu().numpy()
658
+ # mask = self.sam_trans.preprocess(mask)
659
+
660
+
661
+ query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8)
662
+ if self.use_mask:
663
+ masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size)
664
+
665
+ start_time = time.time()
666
+ if self.use_points or self.use_bbox:
667
+ masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False)
668
+ # print(f"predicting w points/bbox took {time.time() - start_time} seconds")
669
+
670
+ pred = sum(masks)
671
+ if not self.training:
672
+ pred = pred > 0
673
+ pred = torch.tensor(pred).float().to(output_p.device)
674
+
675
+ # pred = torch.tensor(masks[0]).float().cuda()
676
+ # resize pred to the size of the input
677
+ pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]
678
+
679
+ return pred, scores
680
+
681
+
682
+ def show_mask(mask, ax, random_color=False):
683
+ if random_color:
684
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
685
+ else:
686
+ color = np.array([30/255, 144/255, 255/255, 0.6])
687
+ h, w = mask.shape[-2:]
688
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
689
+ ax.imshow(mask_image)
690
+
691
+ def show_points(coords, labels, ax, marker_size=375):
692
+ pos_points = coords[labels==1]
693
+ neg_points = coords[labels==0]
694
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
695
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
696
+
697
+ def show_box(box, ax):
698
+ x0, y0 = box[0], box[1]
699
+ w, h = box[2] - box[0], box[3] - box[1]
700
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
701
+
702
+ def need_softmax(tensor, dim=1):
703
+ return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0))
704
+
705
+
706
+
707
+
708
+
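For orientation, a wiring sketch of the ProtoSAM pipeline defined above: ALPNet produces the coarse mask, connected components of that mask become point/box prompts, and SAM refines them. This is illustrative only, not an end-to-end test: the config keys, checkpoint path, sizes and tensors below are assumptions, and actually running it needs the SAM checkpoint (and, with this code, a GPU).

```python
# Hedged usage sketch (assumed config values and paths; random tensors stand in for data).
import torch
from models.grid_proto_fewshot import FewShotSeg
from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET, CONF_MODE

alpnet_cfg = {  # assumed keys, matching what FewShotSeg reads
    'which_model': 'dlfcn_res101', 'use_coco_init': False,
    'cls_name': 'grid_proto', 'proto_grid_size': 8,
    'lora': 0, 'align': False, 'debug': False,
}
coarse = ALPNetWrapper(FewShotSeg(image_size=448, pretrained_path=None, cfg=alpnet_cfg))
model = ProtoSAM(image_size=448, coarse_segmentation_model=coarse,
                 sam_pretrained_path="pretrained_model/sam_vit_b.pth",   # assumed path
                 use_points=True, use_cca=True, point_mode=CONF_MODE).eval()

sup_img = torch.randn(1, 3, 448, 448)                     # support image
sup_lbl = (torch.rand(1, 448, 448) > 0.9).float()         # binary support mask
qry_img = torch.randn(1, 3, 448, 448)                     # query image

coarse_in = InputFactory.create_input(TYPE_ALPNET, qry_img,
                                      support_images=[sup_img], support_labels=[sup_lbl],
                                      isval=True, val_wsize=2)
with torch.no_grad():
    pred, scores = model(qry_img, coarse_in)              # pred: refined (H, W) mask
```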
models/SamWrapper.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
5
+ from models.segment_anything.utils.transforms import ResizeLongestSide
6
+ import cv2
7
+
8
+ def get_iou(mask, label):
9
+ tp = (mask * label).sum()
10
+ fp = (mask * (1-label)).sum()
11
+ fn = ((1-mask) * label).sum()
12
+ iou = tp / (tp + fp + fn)
13
+ return iou
14
+
15
+ class SamWrapper(nn.Module):
16
+ def __init__(self,sam_args):
17
+ """
18
+ sam_args: dict should include the following
19
+ {
20
+ "model_type": "vit_h",
21
+ "sam_checkpoint": "path to checkpoint" pretrained_model/sam_vit_h.pth
22
+ }
23
+ """
24
+ super().__init__()
25
+ self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
26
+ self.mask_generator = SamAutomaticMaskGenerator(self.sam)
27
+ self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)
28
+
29
+ def forward(self, image, image_labels):
30
+ """
31
+ generate masks for a batch of images
32
+ return mask that has the largest iou with the image label
33
+ Args:
34
+ images (np.ndarray): The image to generate masks for, in HWC uint8 format.
35
+ image_labels (np.ndarray): The image labels to generate masks for, in HWC uint8 format. assuming binary labels
36
+ """
37
+ image = self.transform.apply_image(image)
38
+ masks = self.mask_generator.generate(image)
39
+
40
+ best_index, best_iou = None, 0
41
+ for i, mask in enumerate(masks):
42
+ segmentation = mask['segmentation']
43
+ iou = get_iou(segmentation.astype(np.uint8), image_labels)
44
+ if best_index is None or iou > best_iou:
45
+ best_index = i
46
+ best_iou = iou
47
+
48
+ return masks[best_index]['segmentation']
49
+
50
+ def to(self, device):
51
+ self.sam.to(device)
52
+ self.mask_generator.to(device)
53
+ self.mask_generator.predictor.to(device)
54
+
55
+
56
+
57
+ if __name__ == "__main__":
58
+ sam_args = {
59
+ "model_type": "vit_h",
60
+ "sam_checkpoint": "pretrained_model/sam_vit_h.pth"
61
+ }
62
+ sam_wrapper = SamWrapper(sam_args).cuda()
63
+ image = cv2.imread("./Kheops-Pyramid.jpg")
64
+ image = np.array(image).astype('uint8')
65
+ image_labels = torch.rand(1,3,224,224)
66
+ sam_wrapper(image, image_labels)
67
+
68
+
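`SamWrapper` above is effectively an oracle selector: it runs SAM's automatic mask generator and keeps the proposal with the highest IoU against the provided label, so its output is an upper bound on what unguided proposals can achieve for that label. The IoU it uses is TP / (TP + FP + FN); a tiny check of that definition on toy masks:

```python
# Toy check of the IoU definition used by get_iou above.
import numpy as np
from models.SamWrapper import get_iou

mask  = np.array([[1, 1, 0, 0]], dtype=np.uint8)   # prediction
label = np.array([[0, 1, 1, 0]], dtype=np.uint8)   # ground truth
print(get_iou(mask, label))                        # TP=1, FP=1, FN=1 -> 1/3 ~= 0.333
```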
models/__init__.py ADDED
File without changes
models/__pycache__/ProtoSAM.cpython-312.pyc ADDED
Binary file (42.4 kB).
 
models/__pycache__/SamWrapper.cpython-312.pyc ADDED
Binary file (3.84 kB).
 
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes).
 
models/__pycache__/alpmodule.cpython-312.pyc ADDED
Binary file (12 kB).
 
models/__pycache__/grid_proto_fewshot.cpython-312.pyc ADDED
Binary file (23.4 kB).
 
models/alpmodule.py ADDED
@@ -0,0 +1,199 @@
1
+ """
2
+ ALPModule
3
+ """
4
+ import torch
5
+ import time
6
+ import math
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ import numpy as np
10
+ from pdb import set_trace
11
+ import matplotlib.pyplot as plt
12
+ # for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm
13
+
14
+ def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
15
+ x_norm = torch.norm(x, p = p, dim = dim) # .detach()
16
+ x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
17
+ x = x.div(x_norm.unsqueeze(1).expand_as(x))
18
+ return x
19
+
20
+
21
+ class MultiProtoAsConv(nn.Module):
22
+ def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode = 'bilinear'):
23
+ """
24
+ ALPModule
25
+ Args:
26
+ proto_grid: Grid size when doing multi-prototyping. For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2
27
+ feature_hw: Spatial size of input feature map
28
+
29
+ """
30
+ super(MultiProtoAsConv, self).__init__()
31
+ self.feature_hw = feature_hw
32
+ self.proto_grid = proto_grid
33
+ self.upsample_mode = upsample_mode
34
+ kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ]
35
+ self.kernel_size = kernel_size
36
+ print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
37
+ self.avg_pool_op = nn.AvgPool2d( kernel_size )
38
+
39
+ if use_attention:
40
+ self.proto_fg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
41
+ self.proto_bg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
42
+ self.fg_mask_projection = nn.Sequential(
43
+ nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
44
+ nn.ReLU(inplace=True),
45
+ nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
46
+ nn.ReLU(inplace=True),
47
+ nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
48
+ )
49
+ self.bg_mask_projection = nn.Sequential(
50
+ nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
51
+ nn.ReLU(inplace=True),
52
+ nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
53
+ nn.ReLU(inplace=True),
54
+ nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
55
+ )
56
+
57
+ def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False ):
58
+ if mode == 'mask':
59
+ pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w]
60
+ # in case more than one prototype falls at the same location, take the max
61
+ pred_mask = pred_mask.max(dim = 0)[0].unsqueeze(0)
62
+ vis_dict = {'proto_assign': pred_mask} # things to visualize
63
+ if vis_sim:
64
+ vis_dict['raw_local_sims'] = pred_mask
65
+ return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w]
66
+
67
+ elif mode == 'gridconv':
68
+ dists = F.conv2d(query, prototypes[..., None, None]) * 20
69
+
70
+ pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
71
+ debug_assign = dists.argmax(dim = 1).float().detach()
72
+
73
+ vis_dict = {'proto_assign': debug_assign} # things to visualize
74
+
75
+ if vis_sim: # return the similarity for visualization
76
+ vis_dict['raw_local_sims'] = dists.clone().detach()
77
+ return pred_grid, [debug_assign], vis_dict
78
+
79
+ elif mode == 'gridconv+':
80
+ dists = F.conv2d(query, prototypes[..., None, None]) * 20
81
+
82
+ pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
83
+ # raw_local_sims = dists.detach()
84
+
85
+ debug_assign = dists.argmax(dim = 1).float()
86
+
87
+ vis_dict = {'proto_assign': debug_assign}
88
+ if vis_sim:
89
+ vis_dict['raw_local_sims'] = dists.clone().detach()
90
+
91
+ return pred_grid, [debug_assign], vis_dict
92
+
93
+ else:
94
+ raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
95
+
96
+
97
+ def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval = False):
98
+ if mode == 'mask':
99
+ proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
100
+ / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C
101
+
102
+ pro_n = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything
103
+ pro_n = proto
104
+ proto_grid = sup_y.clone().detach() # a single prototype for the whole image
105
+ resized_proto_grid = proto_grid
106
+ non_zero = torch.nonzero(proto_grid)
107
+
108
+ elif mode == 'gridconv':
109
+ nch = sup_x.shape[1]
110
+
111
+ sup_nshot = sup_x.shape[0]
112
+ # if len(sup_x.shape) > 4:
113
+ # sup_x = sup_x.squeeze()
114
+ n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
115
+ n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc
116
+ n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
117
+
118
+ sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
119
+
120
+ # get a grid of prototypes
121
+ proto_grid = sup_y_g.clone().detach()
122
+ proto_grid[proto_grid < thresh] = 0
123
+ # interpolate the grid to the original size
124
+ non_zero = torch.nonzero(proto_grid)
125
+
126
+ resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
127
+ for index in non_zero:
128
+ resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]
129
+
130
+ sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
131
+ protos = n_sup_x[sup_y_g > thresh, :] # npro, nc
132
+ pro_n = safe_norm(protos)
133
+
134
+ elif mode == 'gridconv+':
135
+ nch = sup_x.shape[1]
136
+ n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
137
+ sup_nshot = sup_x.shape[0]
138
+ n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0)
139
+ n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
140
+ sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
141
+
142
+ # get a grid of prototypes
143
+ proto_grid = sup_y_g.clone().detach()
144
+ proto_grid[proto_grid < thresh] = 0
145
+ non_zero = torch.nonzero(proto_grid)
146
+ for i, idx in enumerate(non_zero):
147
+ proto_grid[0, idx[1], idx[2], idx[3]] = i + 1
148
+ resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
149
+ for index in non_zero:
150
+ resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]
151
+
152
+ sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
153
+ protos = n_sup_x[sup_y_g > thresh, :]
154
+
155
+ glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
156
+ / (sup_y.sum(dim=(-1, -2)) + 1e-5)
157
+
158
+ pro_n = safe_norm(torch.cat( [protos, glb_proto], dim = 0 ))
159
+ return pro_n, resized_proto_grid, non_zero
160
+
161
+ def forward(self, qry, sup_x, sup_y, mode, thresh, isval = False, val_wsize = None, vis_sim = False, get_prototypes=False, **kwargs):
162
+ """
163
+ Now supports
164
+ Args:
165
+ mode: 'mask'/ 'grid'. if mask, works as original prototyping
166
+ qry: [way(1), nc, h, w]
167
+ sup_x: [nb, nc, h, w]
168
+ sup_y: [nb, 1, h, w]
169
+ vis_sim: visualize raw similarities or not
170
+ New
171
+ mode: 'mask'/ 'grid'. if mask, works as original prototyping
172
+ qry: [way(1), nb(1), nc, h, w]
173
+ sup_x: [way(1), shot, nb(1), nc, h, w]
174
+ sup_y: [way(1), shot, nb(1), h, w]
175
+ vis_sim: visualize raw similarities or not
176
+ """
177
+
178
+ qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w]
179
+ sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w]
180
+ sup_y = sup_y.squeeze(0) # [nshot, 1, h, w]
181
+
182
+ def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
183
+ x_norm = torch.norm(x, p = p, dim = dim) # .detach()
184
+ x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
185
+ x = x.div(x_norm.unsqueeze(1).expand_as(x))
186
+ return x
187
+ if val_wsize is None:
188
+ val_wsize = self.avg_pool_op.kernel_size
189
+ if isinstance(val_wsize, (tuple, list)):
190
+ val_wsize = val_wsize[0]
191
+ sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])
192
+ pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
193
+ if 0 in pro_n.shape:
194
+ print("failed to find prototypes")
195
+ qry_n = qry if mode == 'mask' else safe_norm(qry)
196
+ pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
197
+
198
+ return pred_grid, debug_assign, vis_dict, proto_grid
199
+
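The `gridconv` / `gridconv+` branches above score the query by using L2-normalised prototypes as 1x1 convolution kernels over L2-normalised query features: each output channel is a per-location cosine similarity (scaled by 20), which is then softly max-pooled over prototypes. A standalone sketch of that equivalence (illustration only, not repo code):

```python
# With unit-norm features and prototypes, a 1x1 conv equals per-pixel cosine similarity.
import torch
import torch.nn.functional as F

feat   = F.normalize(torch.randn(1, 256, 32, 32), dim=1)    # query features, unit norm per pixel
protos = F.normalize(torch.randn(5, 256), dim=1)             # 5 prototypes, unit norm

dists = F.conv2d(feat, protos[..., None, None]) * 20         # (1, 5, 32, 32) scaled cosine sims
pred  = (F.softmax(dists, dim=1) * dists).sum(dim=1, keepdim=True)  # soft max over prototypes

# sanity check against an explicit cosine similarity for prototype 0
ref = F.cosine_similarity(feat, protos[0].view(1, 256, 1, 1), dim=1) * 20
assert torch.allclose(dists[:, 0], ref, atol=1e-5)
```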
models/backbone/__init__.py ADDED
File without changes
models/backbone/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes).
 
models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc ADDED
Binary file (2.78 kB).
 
models/backbone/torchvision_backbones.py ADDED
@@ -0,0 +1,58 @@
1
+ """
2
+ Backbones supported by torchvison.
3
+ """
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import torchvision
11
+
12
+ class TVDeeplabRes101Encoder(nn.Module):
13
+ """
14
+ FCN-Resnet101 backbone from torchvision deeplabv3
15
+ No ASPP is used as we found empirically that it hurts performance
16
+ """
17
+ def __init__(self, use_coco_init, aux_dim_keep = 64, use_aspp = False):
18
+ super().__init__()
19
+ _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None)
20
+ if use_coco_init:
21
+ print("###### NETWORK: Using ms-coco initialization ######")
22
+ else:
23
+ print("###### NETWORK: Training from scratch ######")
24
+
25
+ _model_list = list(_model.children())
26
+ self.aux_dim_keep = aux_dim_keep
27
+ self.backbone = _model_list[0]
28
+ self.localconv = nn.Conv2d(2048, 256,kernel_size = 1, stride = 1, bias = False) # reduce feature map dimension
29
+ self.asppconv = nn.Conv2d(256, 256,kernel_size = 1, bias = False)
30
+
31
+ _aspp = _model_list[1][0]
32
+ _conv256 = _model_list[1][1]
33
+ self.aspp_out = nn.Sequential(*[_aspp, _conv256] )
34
+ self.use_aspp = use_aspp
35
+
36
+ def forward(self, x_in, low_level):
37
+ """
38
+ Args:
39
+ low_level: whether returning aggregated low-level features in FCN
40
+ """
41
+ fts = self.backbone(x_in)
42
+ if self.use_aspp:
43
+ fts256 = self.aspp_out(fts['out'])
44
+ high_level_fts = fts256
45
+ else:
46
+ fts2048 = fts['out']
47
+ high_level_fts = self.localconv(fts2048)
48
+
49
+ if low_level:
50
+ low_level_fts = fts['aux'][:, : self.aux_dim_keep]
51
+ return high_level_fts, low_level_fts
52
+ else:
53
+ return high_level_fts
54
+
55
+
56
+
57
+
58
+
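A quick shape check for the encoder above (assuming the torchvision DeepLabv3/ResNet-101 model can be instantiated locally): the dilated ResNet-101 backbone runs at output stride 8 and `localconv` projects its 2048 channels down to 256, which matches the `feature_hw = ceil(image_size / 8)` setting used by ALPNet.

```python
# Shape check: a 256x256 input yields a 1x256x32x32 feature map (stride 8, 256 channels).
import torch
from models.backbone.torchvision_backbones import TVDeeplabRes101Encoder

enc = TVDeeplabRes101Encoder(use_coco_init=False).eval()
with torch.no_grad():
    fts = enc(torch.randn(1, 3, 256, 256), low_level=False)
print(fts.shape)  # torch.Size([1, 256, 32, 32])
```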
models/grid_proto_fewshot.py ADDED
@@ -0,0 +1,427 @@
1
+ """
2
+ ALPNet
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .alpmodule import MultiProtoAsConv
8
+ from .backbone.torchvision_backbones import TVDeeplabRes101Encoder
9
+ from util.consts import DEFAULT_FEATURE_SIZE
10
+ from util.lora import inject_trainable_lora
11
+ # from util.utils import load_config_from_url, plot_dinov2_fts
12
+ import math
13
+
14
+ # Specify a local path to the repository (or use installed package instead)
15
+ FG_PROT_MODE = 'gridconv+' # using both local and global prototype
16
+ # FG_PROT_MODE = 'mask'
17
+ # using local prototype only. Also 'mask' refers to using global prototype only (as done in vanilla PANet)
18
+ BG_PROT_MODE = 'gridconv'
19
+
20
+ # thresholds for deciding class of prototypes
21
+ FG_THRESH = 0.95
22
+ BG_THRESH = 0.95
23
+
24
+
25
+ class FewShotSeg(nn.Module):
26
+ """
27
+ ALPNet
28
+ Args:
29
+ in_channels: Number of input channels
30
+ cfg: Model configurations
31
+ """
32
+
33
+ def __init__(self, image_size, pretrained_path=None, cfg=None):
34
+ super(FewShotSeg, self).__init__()
35
+ self.image_size = image_size
36
+ self.pretrained_path = pretrained_path
37
+ print(f'###### Pre-trained path: {self.pretrained_path} ######')
38
+ self.config = cfg or {
39
+ 'align': False, 'debug': False}
40
+ self.get_encoder()
41
+ self.get_cls()
42
+ if self.pretrained_path:
43
+ self.load_state_dict(torch.load(self.pretrained_path), strict=True)
44
+ print(
45
+ f'###### Pre-trained model {self.pretrained_path} has been loaded ######')
46
+
47
+ def get_encoder(self):
48
+ self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE,
49
+ DEFAULT_FEATURE_SIZE] # default feature map size
50
+ if self.config['which_model'] == 'dlfcn_res101' or self.config['which_model'] == 'default':
51
+ use_coco_init = self.config['use_coco_init']
52
+ self.encoder = TVDeeplabRes101Encoder(use_coco_init)
53
+ self.config['feature_hw'] = [
54
+ math.ceil(self.image_size/8), math.ceil(self.image_size/8)]
55
+ elif self.config['which_model'] == 'dinov2_l14':
56
+ self.encoder = torch.hub.load(
57
+ 'facebookresearch/dinov2', 'dinov2_vitl14')
58
+ self.config['feature_hw'] = [max(
59
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
60
+ elif self.config['which_model'] == 'dinov2_l14_reg':
61
+ try:
62
+ self.encoder = torch.hub.load(
63
+ 'facebookresearch/dinov2', 'dinov2_vitl14_reg')
64
+ except RuntimeError as e:
65
+ self.encoder = torch.hub.load(
66
+ 'facebookresearch/dino', 'dinov2_vitl14_reg', force_reload=True)
67
+ self.config['feature_hw'] = [max(
68
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
69
+ elif self.config['which_model'] == 'dinov2_b14':
70
+ self.encoder = torch.hub.load(
71
+ 'facebookresearch/dinov2', 'dinov2_vitb14')
72
+ self.config['feature_hw'] = [max(
73
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
74
+ else:
75
+ raise NotImplementedError(
76
+ f'Backbone network {self.config["which_model"]} not implemented')
77
+
78
+ if self.config['lora'] > 0:
79
+ self.encoder.requires_grad_(False)
80
+ print(f'Injecting LoRA with rank:{self.config["lora"]}')
81
+ encoder_lora_params = inject_trainable_lora(
82
+ self.encoder, r=self.config['lora'])
83
+
84
+ def get_features(self, imgs_concat):
85
+ if self.config['which_model'] in ('dlfcn_res101', 'default'):
86
+ img_fts = self.encoder(imgs_concat, low_level=False)
87
+ elif 'dino' in self.config['which_model']:
88
+ # resize imgs_concat to the closest size that is divisble by 14
89
+ imgs_concat = F.interpolate(imgs_concat, size=(
90
+ self.image_size // 14 * 14, self.image_size // 14 * 14), mode='bilinear')
91
+ dino_fts = self.encoder.forward_features(imgs_concat)
92
+ img_fts = dino_fts["x_norm_patchtokens"] # B, HW, C
93
+ img_fts = img_fts.permute(0, 2, 1) # B, C, HW
94
+ C, HW = img_fts.shape[-2:]
95
+ img_fts = img_fts.view(-1, C, int(HW**0.5),
96
+ int(HW**0.5)) # B, C, H, W
97
+ if HW < DEFAULT_FEATURE_SIZE ** 2:
98
+ img_fts = F.interpolate(img_fts, size=(
99
+ DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE), mode='bilinear') # this is if h,w < (32,32)
100
+ else:
101
+ raise NotImplementedError(
102
+ f'Backbone network {self.config["which_model"]} not implemented')
103
+
104
+ return img_fts
105
+
106
+ def get_cls(self):
107
+ """
108
+ Obtain the similarity-based classifier
109
+ """
110
+ proto_hw = self.config["proto_grid_size"]
111
+
112
+ if self.config['cls_name'] == 'grid_proto':
113
+ embed_dim = 256
114
+ if 'dinov2_b14' in self.config['which_model']:
115
+ embed_dim = 768
116
+ elif 'dinov2_l14' in self.config['which_model']:
117
+ embed_dim = 1024
118
+ self.cls_unit = MultiProtoAsConv(proto_grid=[proto_hw, proto_hw], feature_hw=self.config["feature_hw"], embed_dim=embed_dim) # when treating it as ordinary prototype
119
+ print(f"cls unit feature hw: {self.cls_unit.feature_hw}")
120
+ else:
121
+ raise NotImplementedError(
122
+ f'Classifier {self.config["cls_name"]} not implemented')
123
+
124
+ def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
125
+ predictions = []
126
+ for res in resolutions:
127
+ supp_imgs_resized = [[F.interpolate(supp_img[0], size=(
128
+ res, res), mode='bilinear') for supp_img in supp_imgs]] if supp_imgs[0][0].shape[-1] != res else supp_imgs
129
+ fore_mask_resized = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
130
+ 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != res else fore_mask
131
+ back_mask_resized = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
132
+ 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != res else back_mask
133
+ qry_imgs_resized = [F.interpolate(qry_img, size=(res, res), mode='bilinear')
134
+ for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != res else qry_imgs
135
+
136
+ pred = self.forward(supp_imgs_resized, fore_mask_resized, back_mask_resized,
137
+ qry_imgs_resized, isval, val_wsize, show_viz, supp_fts)[0]
138
+ predictions.append(pred)
139
+
140
+ def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs):
141
+ supp_imgs = [[F.interpolate(supp_img, size=(
142
+ self.image_size, self.image_size), mode='bilinear') for supp_img in supp_imgs_way] for supp_imgs_way in supp_imgs]
143
+ fore_mask = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
144
+ 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != self.image_size else fore_mask
145
+ back_mask = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
146
+ 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != self.image_size else back_mask
147
+ qry_imgs = [F.interpolate(qry_img, size=(self.image_size, self.image_size), mode='bilinear')
148
+ for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != self.image_size else qry_imgs
149
+ return supp_imgs, fore_mask, back_mask, qry_imgs
150
+
151
+ def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
152
+ """
153
+ Args:
154
+ supp_imgs: support images
155
+ way x shot x [B x 3 x H x W], list of lists of tensors
156
+ fore_mask: foreground masks for support images
157
+ way x shot x [B x H x W], list of lists of tensors
158
+ back_mask: background masks for support images
159
+ way x shot x [B x H x W], list of lists of tensors
160
+ qry_imgs: query images
161
+ N x [B x 3 x H x W], list of tensors
162
+ show_viz: return the visualization dictionary
163
+ """
164
+ # ('Please go through this piece of code carefully')
165
+ # supp_imgs, fore_mask, back_mask, qry_imgs = self.resize_inputs_to_image_size(
166
+ # supp_imgs, fore_mask, back_mask, qry_imgs)
167
+
168
+ n_ways = len(supp_imgs)
169
+ n_shots = len(supp_imgs[0])
170
+ n_queries = len(qry_imgs)
171
+
172
+ # NOTE: actual shot in support goes in batch dimension
173
+ assert n_ways == 1, "Multi-way segmentation has not been implemented yet"
174
+ assert n_queries == 1
175
+
176
+ sup_bsize = supp_imgs[0][0].shape[0]
177
+ img_size = supp_imgs[0][0].shape[-2:]
178
+ if self.config["cls_name"] == 'grid_proto_3d':
179
+ img_size = supp_imgs[0][0].shape[-3:]
180
+ qry_bsize = qry_imgs[0].shape[0]
181
+
182
+ imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
183
+ + [torch.cat(qry_imgs, dim=0),], dim=0)
184
+
185
+ img_fts = self.get_features(imgs_concat)
186
+ if len(img_fts.shape) == 5: # for 3D
187
+ fts_size = img_fts.shape[-3:]
188
+ else:
189
+ fts_size = img_fts.shape[-2:]
190
+ if supp_fts is None:
191
+ supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view(
192
+ n_ways, n_shots, sup_bsize, -1, *fts_size) # wa x sh x b x c x h' x w'
193
+ qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view(
194
+ n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W'
195
+ else:
196
+ # N x B x C x H' x W'
197
+ qry_fts = img_fts.view(n_queries, qry_bsize, -1, *fts_size)
198
+
199
+ fore_mask = torch.stack([torch.stack(way, dim=0)
200
+ for way in fore_mask], dim=0) # Wa x Sh x B x H' x W'
201
+ fore_mask = torch.autograd.Variable(fore_mask, requires_grad=True)
202
+ back_mask = torch.stack([torch.stack(way, dim=0)
203
+ for way in back_mask], dim=0) # Wa x Sh x B x H' x W'
204
+
205
+ ###### Compute loss ######
206
+ align_loss = 0
207
+ outputs = []
208
+ visualizes = [] # the buffer for visualization
209
+
210
+ for epi in range(1): # batch dimension, fixed to 1
211
+ fg_masks = [] # keep the way part
212
+
213
+ '''
214
+ for way in range(n_ways):
215
+ # note: index of n_ways starts from 0
216
+ mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this
217
+ mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear')
218
+ fg_masks.append( mean_sup_msk )
219
+
220
+ mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W]
221
+ '''
222
+ # re-interpolate support mask to the same size as support feature
223
+ if len(fts_size) == 3: # TODO make more generic
224
+ res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze(
225
+ 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
226
+ res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze(
227
+ 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
228
+ else:
229
+ res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size=fts_size, mode='nearest')
230
+ for fore_mask_w in fore_mask], dim=0) # [nway, ns, nb, nh', nw']
231
+ res_bg_msk = torch.stack([F.interpolate(back_mask_w, size=fts_size, mode='nearest')
232
+ for back_mask_w in back_mask], dim=0) # [nway, ns, nb, nh', nw']
233
+
234
+ scores = []
235
+ assign_maps = []
236
+ bg_sim_maps = []
237
+ fg_sim_maps = []
238
+ bg_mode = BG_PROT_MODE
239
+
240
+ _raw_score, _, aux_attr, _ = self.cls_unit(
241
+ qry_fts, supp_fts, res_bg_msk, mode=bg_mode, thresh=BG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
242
+ scores.append(_raw_score)
243
+ assign_maps.append(aux_attr['proto_assign'])
244
+
245
+ for way, _msks in enumerate(res_fg_msk):
246
+ raw_scores = []
247
+ for i, _msk in enumerate(_msks):
248
+ _msk = _msk.unsqueeze(0)
249
+ supp_ft = supp_fts[:, i].unsqueeze(0)
250
+ if self.config["cls_name"] == 'grid_proto_3d': # 3D
251
+ k_size = self.cls_unit.kernel_size
252
+ fg_mode = FG_PROT_MODE if F.avg_pool3d(_msk, k_size).max(
253
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' # TODO figure out kernel size
254
+ else:
255
+ k_size = self.cls_unit.kernel_size
256
+ fg_mode = FG_PROT_MODE if F.avg_pool2d(_msk, k_size).max(
257
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
258
+ # TODO figure out kernel size
259
+ _raw_score, _, aux_attr, proto_grid = self.cls_unit(qry_fts, supp_ft, _msk.unsqueeze(
260
+ 0), mode=fg_mode, thresh=FG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
261
+ raw_scores.append(_raw_score)
262
+
263
+ # create a score where each feature is the max of the raw_score
264
+ _raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[
265
+ 0]
266
+ scores.append(_raw_score)
267
+ assign_maps.append(aux_attr['proto_assign'])
268
+ if show_viz:
269
+ fg_sim_maps.append(aux_attr['raw_local_sims'])
270
+ # print(f"Time for fg: {time.time() - start_time}")
271
+ pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
272
+ interpolate_mode = 'bilinear'
273
+ outputs.append(F.interpolate(
274
+ pred, size=img_size, mode=interpolate_mode))
275
+
276
+ ###### Prototype alignment loss ######
277
+ if self.config['align'] and self.training:
278
+ align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
279
+ fore_mask[:, :, epi], back_mask[:, :, epi])
280
+ align_loss += align_loss_epi
281
+
282
+ output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W
283
+ grid_shape = output.shape[2:]
284
+ if self.config["cls_name"] == 'grid_proto_3d':
285
+ grid_shape = output.shape[2:]
286
+ output = output.view(-1, *grid_shape)
287
+ assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None
288
+ bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz else None
289
+ fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz else None
290
+
291
+ return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts
292
+
293
+
294
+ def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
295
+ """
296
+ Compute the loss for the prototype alignment branch
297
+
298
+ Args:
299
+ qry_fts: embedding features for query images
300
+ expect shape: N x C x H' x W'
301
+ pred: predicted segmentation score
302
+ expect shape: N x (1 + Wa) x H x W
303
+ supp_fts: embedding features for support images
304
+ expect shape: Wa x Sh x C x H' x W'
305
+ fore_mask: foreground masks for support images
306
+ expect shape: way x shot x H x W
307
+ back_mask: background masks for support images
308
+ expect shape: way x shot x H x W
309
+ """
310
+ n_ways, n_shots = len(fore_mask), len(fore_mask[0])
311
+
312
+ # Masks for getting query prototype
313
+ pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W'
314
+ binary_masks = [pred_mask == i for i in range(1 + n_ways)]
315
+
316
+ # skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
317
+ # FIXME: for now we make the stronger assumption that a positive class is always present, to avoid under-segmentation / laziness
318
+ skip_ways = []
319
+
320
+ # added for matching dimensions to the new data format
321
+ qry_fts = qry_fts.unsqueeze(0).unsqueeze(
322
+ 2) # added to nway(1) and nb(1)
323
+ # end of added part
324
+
325
+ loss = []
326
+ for way in range(n_ways):
327
+ if way in skip_ways:
328
+ continue
329
+ # Get the query prototypes
330
+ for shot in range(n_shots):
331
+ # actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
332
+ img_fts = supp_fts[way: way + 1, shot: shot + 1]
333
+ size = img_fts.shape[-2:]
334
+ mode = 'bilinear'
335
+ if self.config["cls_name"] == 'grid_proto_3d':
336
+ size = img_fts.shape[-3:]
337
+ mode = 'trilinear'
338
+ qry_pred_fg_msk = F.interpolate(
339
+ binary_masks[way + 1].float(), size=size, mode=mode) # [1 (way), n (shot), h, w]
340
+
341
+ # background
342
+ qry_pred_bg_msk = F.interpolate(
343
+ binary_masks[0].float(), size=size, mode=mode) # 1, n, h ,w
344
+ scores = []
345
+
346
+ bg_mode = BG_PROT_MODE
347
+ _raw_score_bg, _, _, _ = self.cls_unit(
348
+ qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3), mode=bg_mode, thresh=BG_THRESH)
349
+
350
+ scores.append(_raw_score_bg)
351
+ if self.config["cls_name"] == 'grid_proto_3d':
352
+ fg_mode = FG_PROT_MODE if F.avg_pool3d(qry_pred_fg_msk, 4).max(
353
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
354
+ else:
355
+ fg_mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max(
356
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
357
+ _raw_score_fg, _, _, _ = self.cls_unit(
358
+ qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2), mode=fg_mode, thresh=FG_THRESH)
359
+ scores.append(_raw_score_fg)
360
+
361
+ supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
362
+ size = fore_mask.shape[-2:]
363
+ if self.config["cls_name"] == 'grid_proto_3d':
364
+ size = fore_mask.shape[-3:]
365
+ supp_pred = F.interpolate(supp_pred, size=size, mode=mode)
366
+
367
+ # Construct the support Ground-Truth segmentation
368
+ supp_label = torch.full_like(fore_mask[way, shot], 255,
369
+ device=img_fts.device).long()
370
+ supp_label[fore_mask[way, shot] == 1] = 1
371
+ supp_label[back_mask[way, shot] == 1] = 0
372
+ # Compute Loss
373
+ loss.append(F.cross_entropy(
374
+ supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
375
+
376
+ return torch.sum(torch.stack(loss))
377
+
378
+ def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens):
379
+ cls_loss_weight = 0.1
380
+ student_temp = 1
381
+ teacher_cls_tokens = self.sinkhorn_knopp_teacher(teacher_cls_tokens)
382
+ lsm = F.log_softmax(student_cls_tokens / student_temp, dim=-1)
383
+ cls_loss = torch.sum(teacher_cls_tokens * lsm, dim=-1)
384
+
385
+ return -cls_loss.mean() * cls_loss_weight
386
+
387
+ @torch.no_grad()
388
+ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3):
389
+ teacher_output = teacher_output.float()
390
+ # world_size = dist.get_world_size() if dist.is_initialized() else 1
391
+ # Q is K-by-B for consistency with notations from our paper
392
+ Q = torch.exp(teacher_output / teacher_temp).t()
393
+ # B = Q.shape[1] * world_size # number of samples to assign
394
+ B = Q.shape[1]
395
+ K = Q.shape[0] # how many prototypes
396
+
397
+ # make the matrix sums to 1
398
+ sum_Q = torch.sum(Q)
399
+ Q /= sum_Q
400
+
401
+ for it in range(n_iterations):
402
+ # normalize each row: total weight per prototype must be 1/K
403
+ sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
404
+ Q /= sum_of_rows
405
+ Q /= K
406
+
407
+ # normalize each column: total weight per sample must be 1/B
408
+ Q /= torch.sum(Q, dim=0, keepdim=True)
409
+ Q /= B
410
+
411
+ Q *= B # the columns must sum to 1 so that Q is an assignment
412
+ return Q.t()
413
+
414
+ def dino_patch_loss(self, features, masked_features, masks):
415
+ # compute the patch-wise loss for both support and query features
416
+ loss = 0.0
417
+ weight = 0.1
418
+ B = features.shape[0]
419
+ for (f, mf, mask) in zip(features, masked_features, masks):
420
+ # TODO sinkhorn knopp center features
421
+ f = f[mask]
422
+ f = self.sinkhorn_knopp_teacher(f)
423
+ mf = mf[mask]
424
+ loss += torch.sum(f * F.log_softmax(mf / 1,
425
+ dim=-1), dim=-1) / mask.sum()
426
+
427
+ return -loss.sum() * weight / B
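
For orientation, the sketch below drives the forward interface above with dummy tensors shaped as in its docstring. The class name `FewShotSeg` and the constructor arguments are assumptions inferred from this file, not confirmed by the diff.

```python
# Minimal sketch, assuming grid_proto_fewshot.py exposes a FewShotSeg nn.Module
# with the forward signature documented above.
import torch
from models.grid_proto_fewshot import FewShotSeg  # assumed class name

model = FewShotSeg(image_size=256, cfg={"align": True, "cls_name": "grid_proto"})  # assumed ctor
model.eval()

B, H, W = 1, 256, 256
fg = torch.zeros(B, H, W)
fg[:, 64:192, 64:192] = 1.0               # a square foreground region in the support mask
supp_imgs = [[torch.randn(B, 3, H, W)]]   # way x shot x [B x 3 x H x W]
fore_mask = [[fg]]                        # way x shot x [B x H x W]
back_mask = [[1.0 - fg]]                  # complementary background mask
qry_imgs = [torch.randn(B, 3, H, W)]      # N x [B x 3 x H x W]

output, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts = model(
    supp_imgs, fore_mask, back_mask, qry_imgs, isval=True, val_wsize=2)
pred = output.argmax(dim=1)               # per-pixel label: 0 = background, 1 = foreground
```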
models/segment_anything/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam,
9
+ build_sam_vit_h,
10
+ build_sam_vit_l,
11
+ build_sam_vit_b,
12
+ sam_model_registry,
13
+ )
14
+ from .predictor import SamPredictor
15
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
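
These re-exports mirror the upstream segment-anything package. A minimal prompt-based sketch, assuming the bundled `SamPredictor` keeps the upstream `set_image` / `predict` API; the checkpoint and image paths are placeholders.

```python
import cv2
import numpy as np
from models.segment_anything import sam_model_registry, SamPredictor

# Placeholder checkpoint path; official SAM ViT-B weights are commonly named sam_vit_b_01ec64.pth.
sam = sam_model_registry["vit_b"](checkpoint="pretrained_model/sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)  # HWC uint8, placeholder image
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[128, 128]]),  # one positive click, (x, y)
    point_labels=np.array([1]),
    multimask_output=True,
)
```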
models/segment_anything/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (451 Bytes). View file
 
models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc ADDED
Binary file (16.9 kB). View file
 
models/segment_anything/__pycache__/build_sam.cpython-312.pyc ADDED
Binary file (2.77 kB). View file
 
models/segment_anything/__pycache__/predictor.cpython-312.pyc ADDED
Binary file (13.9 kB). View file
 
models/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,380 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
+ class SamAutomaticMaskGenerator:
36
+ def __init__(
37
+ self,
38
+ model: Sam,
39
+ points_per_side: Optional[int] = 32,
40
+ points_per_batch: int = 64,
41
+ pred_iou_thresh: float = 0.88,
42
+ stability_score_thresh: float = 0.95,
43
+ stability_score_offset: float = 1.0,
44
+ box_nms_thresh: float = 0.7,
45
+ crop_n_layers: int = 0,
46
+ crop_nms_thresh: float = 0.7,
47
+ crop_overlap_ratio: float = 512 / 1500,
48
+ crop_n_points_downscale_factor: int = 1,
49
+ point_grids: Optional[List[np.ndarray]] = None,
50
+ min_mask_region_area: int = 0,
51
+ output_mode: str = "binary_mask",
52
+ custom_points: bool = False,
53
+ ) -> None:
54
+ """
55
+ Using a SAM model, generates masks for the entire image.
56
+ Generates a grid of point prompts over the image, then filters
57
+ low quality and duplicate masks. The default settings are chosen
58
+ for SAM with a ViT-H backbone.
59
+
60
+ Arguments:
61
+ model (Sam): The SAM model to use for mask prediction.
62
+ points_per_side (int or None): The number of points to be sampled
63
+ along one side of the image. The total number of points is
64
+ points_per_side**2. If None, 'point_grids' must provide explicit
65
+ point sampling.
66
+ points_per_batch (int): Sets the number of points run simultaneously
67
+ by the model. Higher numbers may be faster but use more GPU memory.
68
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
69
+ model's predicted mask quality.
70
+ stability_score_thresh (float): A filtering threshold in [0,1], using
71
+ the stability of the mask under changes to the cutoff used to binarize
72
+ the model's mask predictions.
73
+ stability_score_offset (float): The amount to shift the cutoff when
74
+ calculating the stability score.
75
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
76
+ suppression to filter duplicate masks.
77
+ crop_n_layers (int): If >0, mask prediction will be run again on
78
+ crops of the image. Sets the number of layers to run, where each
79
+ layer has 2**i_layer number of image crops.
80
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
81
+ suppression to filter duplicate masks between different crops.
82
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
83
+ In the first crop layer, crops will overlap by this fraction of
84
+ the image length. Later layers with more crops scale down this overlap.
85
+ crop_n_points_downscale_factor (int): The number of points-per-side
86
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
87
+ point_grids (list(np.ndarray) or None): A list over explicit grids
88
+ of points used for sampling, normalized to [0,1]. The nth grid in the
89
+ list is used in the nth crop layer. Exclusive with points_per_side.
90
+ min_mask_region_area (int): If >0, postprocessing will be applied
91
+ to remove disconnected regions and holes in masks with area smaller
92
+ than min_mask_region_area. Requires opencv.
93
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
94
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
95
+ For large resolutions, 'binary_mask' may consume large amounts of
96
+ memory.
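+ custom_points (bool): Non-standard extension of the upstream generator: if True,
+ the first half of each point batch is used as positive prompts and the second
+ half as negative prompts (see _process_batch).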
97
+ """
98
+
99
+ assert (points_per_side is None) != (
100
+ point_grids is None
101
+ ), "Exactly one of points_per_side or point_grid must be provided."
102
+ if points_per_side is not None:
103
+ self.point_grids = build_all_layer_point_grids(
104
+ points_per_side,
105
+ crop_n_layers,
106
+ crop_n_points_downscale_factor,
107
+ )
108
+ elif point_grids is not None:
109
+ self.point_grids = point_grids
110
+ else:
111
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
112
+
113
+ assert output_mode in [
114
+ "binary_mask",
115
+ "uncompressed_rle",
116
+ "coco_rle",
117
+ ], f"Unknown output_mode {output_mode}."
118
+ if output_mode == "coco_rle":
119
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
120
+
121
+ if min_mask_region_area > 0:
122
+ import cv2 # type: ignore # noqa: F401
123
+
124
+ self.predictor = SamPredictor(model)
125
+ self.points_per_batch = points_per_batch
126
+ self.pred_iou_thresh = pred_iou_thresh
127
+ self.stability_score_thresh = stability_score_thresh
128
+ self.stability_score_offset = stability_score_offset
129
+ self.box_nms_thresh = box_nms_thresh
130
+ self.crop_n_layers = crop_n_layers
131
+ self.crop_nms_thresh = crop_nms_thresh
132
+ self.crop_overlap_ratio = crop_overlap_ratio
133
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
134
+ self.min_mask_region_area = min_mask_region_area
135
+ self.output_mode = output_mode
136
+ self.custom_points = custom_points
137
+
138
+ @torch.no_grad()
139
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
140
+ """
141
+ Generates masks for the given image.
142
+
143
+ Arguments:
144
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
145
+
146
+ Returns:
147
+ list(dict(str, any)): A list over records for masks. Each record is
148
+ a dict containing the following keys:
149
+ segmentation (dict(str, any) or np.ndarray): The mask. If
150
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
151
+ is a dictionary containing the RLE.
152
+ bbox (list(float)): The box around the mask, in XYWH format.
153
+ area (int): The area in pixels of the mask.
154
+ predicted_iou (float): The model's own prediction of the mask's
155
+ quality. This is filtered by the pred_iou_thresh parameter.
156
+ point_coords (list(list(float))): The point coordinates input
157
+ to the model to generate this mask.
158
+ stability_score (float): A measure of the mask's quality. This
159
+ is filtered on using the stability_score_thresh parameter.
160
+ crop_box (list(float)): The crop of the image used to generate
161
+ the mask, given in XYWH format.
162
+ """
163
+
164
+ # Generate masks
165
+ mask_data = self._generate_masks(image)
166
+
167
+ # Filter small disconnected regions and holes in masks
168
+ if self.min_mask_region_area > 0:
169
+ mask_data = self.postprocess_small_regions(
170
+ mask_data,
171
+ self.min_mask_region_area,
172
+ max(self.box_nms_thresh, self.crop_nms_thresh),
173
+ )
174
+
175
+ # Encode masks
176
+ if self.output_mode == "coco_rle":
177
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
178
+ elif self.output_mode == "binary_mask":
179
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
180
+ else:
181
+ mask_data["segmentations"] = mask_data["rles"]
182
+
183
+ # Write mask records
184
+ curr_anns = []
185
+ for idx in range(len(mask_data["segmentations"])):
186
+ ann = {
187
+ "segmentation": mask_data["segmentations"][idx],
188
+ "area": area_from_rle(mask_data["rles"][idx]),
189
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
190
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
191
+ "point_coords": [mask_data["points"][idx].tolist()],
192
+ "stability_score": mask_data["stability_score"][idx].item(),
193
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
194
+ }
195
+ curr_anns.append(ann)
196
+
197
+ return curr_anns
198
+
199
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
200
+ orig_size = image.shape[:2]
201
+ crop_boxes, layer_idxs = generate_crop_boxes(
202
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
203
+ )
204
+
205
+ # Iterate over image crops
206
+ data = MaskData()
207
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
208
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
209
+ data.cat(crop_data)
210
+
211
+ # Remove duplicate masks between crops
212
+ if len(crop_boxes) > 1:
213
+ # Prefer masks from smaller crops
214
+ scores = 1 / box_area(data["crop_boxes"])
215
+ scores = scores.to(data["boxes"].device)
216
+ keep_by_nms = batched_nms(
217
+ data["boxes"].float(),
218
+ scores,
219
+ torch.zeros_like(data["boxes"][:, 0]), # categories
220
+ iou_threshold=self.crop_nms_thresh,
221
+ )
222
+ data.filter(keep_by_nms)
223
+
224
+ data.to_numpy()
225
+ return data
226
+
227
+ def _process_crop(
228
+ self,
229
+ image: np.ndarray,
230
+ crop_box: List[int],
231
+ crop_layer_idx: int,
232
+ orig_size: Tuple[int, ...],
233
+ ) -> MaskData:
234
+ # Crop the image and calculate embeddings
235
+ x0, y0, x1, y1 = crop_box
236
+ cropped_im = image[y0:y1, x0:x1, :]
237
+ cropped_im_size = cropped_im.shape[:2]
238
+ self.predictor.set_image(cropped_im)
239
+
240
+ # Get points for this crop
241
+ points_scale = np.array(cropped_im_size)[None, ::-1]
242
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
243
+
244
+ # Generate masks for this crop in batches
245
+ data = MaskData()
246
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
247
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
248
+ data.cat(batch_data)
249
+ del batch_data
250
+ self.predictor.reset_image()
251
+
252
+ # Remove duplicates within this crop.
253
+ keep_by_nms = batched_nms(
254
+ data["boxes"].float(),
255
+ data["iou_preds"],
256
+ torch.zeros_like(data["boxes"][:, 0]), # categories
257
+ iou_threshold=self.box_nms_thresh,
258
+ )
259
+ data.filter(keep_by_nms)
260
+
261
+ # Return to the original image frame
262
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
263
+ data["points"] = uncrop_points(data["points"], crop_box)
264
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
265
+
266
+ return data
267
+
268
+ def _process_batch(
269
+ self,
270
+ points: np.ndarray,
271
+ im_size: Tuple[int, ...],
272
+ crop_box: List[int],
273
+ orig_size: Tuple[int, ...],
274
+ ) -> MaskData:
275
+ orig_h, orig_w = orig_size
276
+
277
+ # Run model on this batch
278
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
279
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
280
+ if self.custom_points:
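+ # NOTE: assumes an even number of points; the first half are used as positive prompts, the second half as negative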
281
+ in_pos_labels = torch.ones(in_points.shape[0]//2, dtype=torch.int, device=in_points.device)
282
+ in_neg_labels = torch.zeros_like(in_pos_labels)
283
+ in_labels = torch.cat((in_pos_labels, in_neg_labels), dim = 0)
284
+ else:
285
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
286
+
287
+ masks, iou_preds, _ = self.predictor.predict_torch(
288
+ in_points[:, None, :],
289
+ in_labels[:, None],
290
+ multimask_output=True,
291
+ return_logits=True,
292
+ )
293
+
294
+ # Serialize predictions and store in MaskData
295
+ data = MaskData(
296
+ masks=masks.flatten(0, 1),
297
+ iou_preds=iou_preds.flatten(0, 1),
298
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
299
+ )
300
+ del masks
301
+
302
+ # Filter by predicted IoU
303
+ if self.pred_iou_thresh > 0.0:
304
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
305
+ data.filter(keep_mask)
306
+
307
+ # Calculate stability score
308
+ data["stability_score"] = calculate_stability_score(
309
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
310
+ )
311
+ if self.stability_score_thresh > 0.0:
312
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
313
+ data.filter(keep_mask)
314
+
315
+ # Threshold masks and calculate boxes
316
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
317
+ data["boxes"] = batched_mask_to_box(data["masks"])
318
+
319
+ # Filter boxes that touch crop boundaries
320
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
321
+ if not torch.all(keep_mask):
322
+ data.filter(keep_mask)
323
+
324
+ # Compress to RLE
325
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
326
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
327
+ del data["masks"]
328
+
329
+ return data
330
+
331
+ @staticmethod
332
+ def postprocess_small_regions(
333
+ mask_data: MaskData, min_area: int, nms_thresh: float
334
+ ) -> MaskData:
335
+ """
336
+ Removes small disconnected regions and holes in masks, then reruns
337
+ box NMS to remove any new duplicates.
338
+
339
+ Edits mask_data in place.
340
+
341
+ Requires open-cv as a dependency.
342
+ """
343
+ if len(mask_data["rles"]) == 0:
344
+ return mask_data
345
+
346
+ # Filter small disconnected regions and holes
347
+ new_masks = []
348
+ scores = []
349
+ for rle in mask_data["rles"]:
350
+ mask = rle_to_mask(rle)
351
+
352
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
353
+ unchanged = not changed
354
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
355
+ unchanged = unchanged and not changed
356
+
357
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
358
+ # Give score=0 to changed masks and score=1 to unchanged masks
359
+ # so NMS will prefer ones that didn't need postprocessing
360
+ scores.append(float(unchanged))
361
+
362
+ # Recalculate boxes and remove any new duplicates
363
+ masks = torch.cat(new_masks, dim=0)
364
+ boxes = batched_mask_to_box(masks)
365
+ keep_by_nms = batched_nms(
366
+ boxes.float(),
367
+ torch.as_tensor(scores),
368
+ torch.zeros_like(boxes[:, 0]), # categories
369
+ iou_threshold=nms_thresh,
370
+ )
371
+
372
+ # Only recalculate RLEs for masks that have changed
373
+ for i_mask in keep_by_nms:
374
+ if scores[i_mask] == 0.0:
375
+ mask_torch = masks[i_mask].unsqueeze(0)
376
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
377
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
378
+ mask_data.filter(keep_by_nms)
379
+
380
+ return mask_data
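
A usage sketch for the generator above, following its constructor and generate() docstrings; the checkpoint and image paths are placeholders.

```python
import cv2
from models.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="pretrained_model/sam_vit_b_01ec64.pth")  # placeholder path
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,
    pred_iou_thresh=0.88,
    stability_score_thresh=0.95,
    min_mask_region_area=100,   # requires opencv, as noted in the docstring
)

image = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)  # HWC uint8 RGB, placeholder image
masks = mask_generator.generate(image)
masks = sorted(masks, key=lambda m: m["area"], reverse=True)
largest = masks[0]["segmentation"]      # binary HxW array under the default output_mode
```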
models/segment_anything/build_sam.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SamBatched
12
+
13
+
14
+ def build_sam_vit_h(checkpoint=None):
15
+ return _build_sam(
16
+ encoder_embed_dim=1280,
17
+ encoder_depth=32,
18
+ encoder_num_heads=16,
19
+ encoder_global_attn_indexes=[7, 15, 23, 31],
20
+ checkpoint=checkpoint,
21
+ )
22
+
23
+
24
+ build_sam = build_sam_vit_h
25
+
26
+
27
+ def build_sam_vit_l(checkpoint=None):
28
+ return _build_sam(
29
+ encoder_embed_dim=1024,
30
+ encoder_depth=24,
31
+ encoder_num_heads=16,
32
+ encoder_global_attn_indexes=[5, 11, 17, 23],
33
+ checkpoint=checkpoint,
34
+ )
35
+
36
+
37
+ def build_sam_vit_b(checkpoint=None):
38
+ return _build_sam(
39
+ encoder_embed_dim=768,
40
+ encoder_depth=12,
41
+ encoder_num_heads=12,
42
+ encoder_global_attn_indexes=[2, 5, 8, 11],
43
+ checkpoint=checkpoint,
44
+ )
45
+
46
+
47
+ sam_model_registry = {
48
+ "default": build_sam_vit_h,
49
+ "vit_h": build_sam_vit_h,
50
+ "vit_l": build_sam_vit_l,
51
+ "vit_b": build_sam_vit_b,
52
+ }
53
+
54
+
55
+ def _build_sam(
56
+ encoder_embed_dim,
57
+ encoder_depth,
58
+ encoder_num_heads,
59
+ encoder_global_attn_indexes,
60
+ checkpoint=None,
61
+ ):
62
+ prompt_embed_dim = 256
63
+ image_size = 1024
64
+ vit_patch_size = 16
65
+ image_embedding_size = image_size // vit_patch_size
66
+ sam = SamBatched(
67
+ image_encoder=ImageEncoderViT(
68
+ depth=encoder_depth,
69
+ embed_dim=encoder_embed_dim,
70
+ img_size=image_size,
71
+ mlp_ratio=4,
72
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73
+ num_heads=encoder_num_heads,
74
+ patch_size=vit_patch_size,
75
+ qkv_bias=True,
76
+ use_rel_pos=True,
77
+ global_attn_indexes=encoder_global_attn_indexes,
78
+ window_size=14,
79
+ out_chans=prompt_embed_dim,
80
+ ),
81
+ prompt_encoder=PromptEncoder(
82
+ embed_dim=prompt_embed_dim,
83
+ image_embedding_size=(image_embedding_size, image_embedding_size),
84
+ input_image_size=(image_size, image_size),
85
+ mask_in_chans=16,
86
+ ),
87
+ mask_decoder=MaskDecoder(
88
+ num_multimask_outputs=3,
89
+ transformer=TwoWayTransformer(
90
+ depth=2,
91
+ embedding_dim=prompt_embed_dim,
92
+ mlp_dim=2048,
93
+ num_heads=8,
94
+ ),
95
+ transformer_dim=prompt_embed_dim,
96
+ iou_head_depth=3,
97
+ iou_head_hidden_dim=256,
98
+ ),
99
+ pixel_mean=[123.675, 116.28, 103.53],
100
+ pixel_std=[58.395, 57.12, 57.375],
101
+ )
102
+ sam.eval()
103
+ if checkpoint is not None:
104
+ with open(checkpoint, "rb") as f:
105
+ state_dict = torch.load(f)
106
+ sam.load_state_dict(state_dict)
107
+ return sam
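
Since `_build_sam` here instantiates `SamBatched` rather than the upstream `Sam`, it can be worth sanity-checking that an official checkpoint still maps onto the model's parameter names. A minimal sketch, with a placeholder checkpoint path:

```python
import torch
from models.segment_anything.build_sam import build_sam_vit_b

sam = build_sam_vit_b(checkpoint=None)  # builds the SamBatched variant without loading weights
state_dict = torch.load("pretrained_model/sam_vit_b_01ec64.pth", map_location="cpu")  # placeholder path
missing, unexpected = sam.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)        # both lists should be empty if SamBatched
print("unexpected keys:", unexpected)  # keeps the upstream parameter names
```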
models/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam, SamBatched
8
+ from .image_encoder import ImageEncoderViT
9
+ from .mask_decoder import MaskDecoder
10
+ from .prompt_encoder import PromptEncoder
11
+ from .transformer import TwoWayTransformer
models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (454 Bytes). View file
 
models/segment_anything/modeling/__pycache__/common.cpython-312.pyc ADDED
Binary file (2.91 kB). View file
 
models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc ADDED
Binary file (19.3 kB). View file
 
models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc ADDED
Binary file (8.53 kB). View file
 
models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc ADDED
Binary file (12.3 kB). View file
 
models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc ADDED
Binary file (12.3 kB). View file