diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3e88c65dff890b23ddb559b362dc65682456f201
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+exps/
+.vscode/
+debug/
+test_*/
+pretrained_model/*
\ No newline at end of file
diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem
new file mode 100644
index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3
--- /dev/null
+++ b/.gradio/certificate.pem
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..3877ae0a7ff6f94ac222fd704e112723db776114
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/README.md b/README.md
index 6e887e497ce53f89ee0d3d506fdd06fae406ead7..53b5017d8040fe9d2361cc9f60f6cf435d94a5dd 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,157 @@
----
-title: LoGoSAM Demo
-emoji: 📈
-colorFrom: purple
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.30.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: LoGoSAM_demo
+app_file: app.py
+sdk: gradio
+sdk_version: 5.29.0
+---
+# ProtoSAM - One shot segmentation with foundational models
+
+Link to our paper [here](https://arxiv.org/abs/2407.07042). \
+This work is the successor of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning) (Link to [Paper](https://arxiv.org/abs/2403.03273)).
+
+## Demo Application
+
+A Gradio-based demo application is now available for interactive inference with ProtoSAM. You can upload your own images and masks to test the model. See [README_DEMO.md](README_DEMO.md) for instructions on running the demo.
+
+## Abstract
+This work introduces a new framework, ProtoSAM, for one-shot image segmentation. It combines DINOv2, a vision transformer that extracts features from images, with an Adaptive Local Prototype Pooling (ALP) layer, which generates prototypes from a support image and its mask. These prototypes are used to create an initial coarse segmentation mask by comparing the query image's features with the prototypes.
+Following the extraction of an initial mask, we use numerical methods to generate prompts, such as points and bounding boxes, which are then input into the Segment Anything Model (SAM), a prompt-based segmentation model trained on natural images. This allows segmenting new classes automatically and effectively without the need for additional training.
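+
+To illustrate the core idea of the coarse stage, the sketch below shows prototype matching in its simplest form. It is illustrative only: it uses global masked average pooling and a fixed cosine-similarity threshold rather than the actual ALP layer, and it omits the SAM prompting stage entirely.
+
+```python
+import torch
+import torch.nn.functional as F
+
+def coarse_mask_from_prototype(support_feats, support_mask, query_feats, threshold=0.5):
+    """Toy prototype matching: feature maps are (C, H, W), the mask is (H, W) in {0, 1}."""
+    # Masked average pooling: average the support features over the foreground region
+    fg = support_mask.unsqueeze(0)                                        # (1, H, W)
+    prototype = (support_feats * fg).sum(dim=(1, 2)) / (fg.sum() + 1e-6)  # (C,)
+    # Cosine similarity between every query location and the prototype
+    sim = F.cosine_similarity(
+        query_feats, prototype[:, None, None].expand_as(query_feats), dim=0
+    )                                                                     # (H, W)
+    return (sim > threshold).float()                                      # coarse binary mask
+
+# Example with random feature maps
+sup_f, qry_f = torch.randn(256, 32, 32), torch.randn(256, 32, 32)
+sup_m = (torch.rand(32, 32) > 0.5).float()
+coarse_mask = coarse_mask_from_prototype(sup_f, sup_m, qry_f)
+```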
+
+## How To Run
+### 1. Data preprocessing
+#### 1.1 CT and MRI Dataset
+Please see the notebook `data/data_processing.ipynb` for instructions.
+For convenience, I've compiled the data processing instructions from https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation into a single notebook. \
+The CT dataset is available here: https://www.synapse.org/Synapse:syn3553734 \
+The MRI dataset is available here: https://chaos.grand-challenge.org
+
+Run `./data/CHAOST2/dcm_img_to_nii.sh` to convert the DICOM images to NIfTI files.
+
+#### 1.2 Polyp Dataset
+Data is available here: https://www.kaggle.com/datasets/hngphmv/polypdataset?select=train.csv
+
+Put the dataset in `data/PolypDataset/`.
+
+### 2. Running
+#### 2.1 (Optional) Training and Validation of the coarse segmentation networks
+```
+./backbone.sh [MODE] [MODALITY] [LABEL_SET]
+```
+MODE - validation or training \
+MODALITY - ct or mri \
+LABEL_SET - 0 (kidneys), 1 (liver and spleen)
+
+For example:
+```
+./backbone.sh training mri 1
+```
+Please refer to `backbone.sh` for further configurations.
+
+#### 2.2 Running ProtoSAM
+Put all SAM checkpoints, such as sam_vit_b.pth, sam_vit_h.pth, and medsam_vit_b.pth, into the `pretrained_model` directory. \
+Checkpoints are available at [SAM](https://github.com/facebookresearch/segment-anything) and [MedSAM](https://github.com/bowang-lab/MedSAM).
+
+```
+./run_protosam.sh [MODALITY] [LABEL_SET]
+```
+MODALITY - ct, mri or polyp \
+LABEL_SET (only relevant for ct or mri) - 0 (kidneys), 1 (liver and spleen) \
+Please refer to the `run_protosam.sh` script for further configurations.
+
+
+## Acknowledgements
+This work is largely based on [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [DINOv2](https://github.com/facebookresearch/dinov2), [SAM](https://github.com/facebookresearch/segment-anything) and is a continuation of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning).
+
+## Cite
+If you found this repo useful, please consider giving us a citation and a star!
+
+```bibtex
+@article{ayzenberg2024protosam,
+ title={ProtoSAM-One Shot Medical Image Segmentation With Foundational Models},
+ author={Ayzenberg, Lev and Giryes, Raja and Greenspan, Hayit},
+ journal={arXiv preprint arXiv:2407.07042},
+ year={2024}
+}
+
+@misc{ayzenberg2024dinov2,
+ title={DINOv2 based Self Supervised Learning For Few Shot Medical Image Segmentation},
+ author={Lev Ayzenberg and Raja Giryes and Hayit Greenspan},
+ year={2024},
+ eprint={2403.03273},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+```
+
+# ProtoSAM Segmentation Demo
+
+This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.
+
+## Requirements
+
+- Python 3.8 or higher
+- CUDA-compatible GPU
+- Required Python packages (see `requirements.txt`)
+
+## Setup Instructions
+
+1. Clone this repository:
+```bash
+git clone <repository-url>
+cd <repository-name>
+```
+
+2. Create and activate a virtual environment (optional but recommended):
+```bash
+python -m venv venv
+source venv/bin/activate # On Windows: venv\Scripts\activate
+```
+
+3. Install the required dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+4. Download the pretrained models:
+```bash
+mkdir -p pretrained_model
+# Download SAM ViT-H model
+wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
+```
+
+5. Update the model path in `app.py`:
+ - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.
+
+## Running the App
+
+Start the app with:
+```bash
+python app.py
+```
+
+This will start the Gradio server and print a link for opening the segmentation demo in your browser.
+
+## Usage
+
+1. Upload a query image (the image you want to segment)
+2. Upload a support image (an example image with a similar object)
+3. Upload a support mask (the segmentation mask for the support image)
+4. Configure the model parameters using the checkboxes if needed
+5. Click "Run Inference" to generate the segmentation result
+
+## Model Configuration
+
+The app allows you to configure several model parameters:
+- Use Bounding Box: Enable/disable bounding box input
+- Use Points: Enable/disable point input
+- Use Mask: Enable/disable mask input
+- Use CCA: Enable/disable Connected Component Analysis
+- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement
+
+## Notes
+
+- This demo requires a GPU with CUDA support
+- Large images may require more GPU memory
+- For optimal results, use high-quality support images and masks
diff --git a/README_DEMO.md b/README_DEMO.md
new file mode 100644
index 0000000000000000000000000000000000000000..d11c2f105a573ad9df1acd14c4fc14d396bb4a95
--- /dev/null
+++ b/README_DEMO.md
@@ -0,0 +1,76 @@
+# ProtoSAM Segmentation Demo
+
+This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.
+
+## Requirements
+
+- Python 3.8 or higher
+- CUDA-compatible GPU
+- Required Python packages (see `requirements.txt`)
+
+## Setup Instructions
+
+1. Clone this repository:
+```bash
+git clone <repository-url>
+cd <repository-name>
+```
+
+2. Create and activate a virtual environment (optional but recommended):
+```bash
+python -m venv venv
+source venv/bin/activate # On Windows: venv\Scripts\activate
+```
+
+3. Install the required dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+4. Download the pretrained models:
+```bash
+mkdir -p pretrained_model
+# Download SAM ViT-H model
+wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
+```
+
+5. Update the model path in `app.py`:
+ - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.
+
+## Running the App
+
+Start the app with:
+```bash
+./run_demo.sh
+```
+
+Or run it directly with:
+```bash
+python app.py
+```
+
+This will start the server and provide a link to access the demo in your browser.
+
+## Usage
+
+1. Upload a query image (the image you want to segment)
+2. Upload a support image (an example image with a similar object)
+3. Upload a support mask (the segmentation mask for the support image)
+4. Configure the model parameters using the checkboxes
+5. Click "Run Inference" to generate the segmentation result
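+
+If you prefer to skip the UI, the same inference path can be driven from Python by calling `run_inference` from `app.py` directly. The snippet below is a minimal sketch: it assumes a CUDA GPU, the checkpoints and model path set up above, and example images of your own (the file paths are placeholders).
+
+```python
+from PIL import Image
+from app import run_inference
+
+# Placeholder paths: replace with your own query/support images and a binary support mask
+query = Image.open("examples/query.png")
+support = Image.open("examples/support.png")
+support_mask = Image.open("examples/support_mask.png")  # white = foreground
+
+overlay, info = run_inference(
+    query, support, support_mask,
+    use_bbox=True, use_points=True, use_mask=True,
+    use_cca=True, coarse_pred_only=False,
+)
+print(info)  # confidence score, or an error message if inference failed
+```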
+
+## Model Configuration
+
+The app allows you to configure several model parameters:
+- Use Bounding Box: Enable/disable bounding box input
+- Use Points: Enable/disable point input
+- Use Mask: Enable/disable mask input
+- Use CCA: Enable/disable Connected Component Analysis
+- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement
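+
+Each of these switches overrides the corresponding entry of the `config` dictionary defined in `app.py` before the model is built. The snippet below mirrors that dictionary with the UI-controlled fields annotated; the checkpoint path is a placeholder that must point at your own trained coarse model.
+
+```python
+from models.ProtoSAM import TYPE_ALPNET
+
+config = {
+    "input_size": [224],
+    "reload_model_path": "path/to/your/model.pth",  # placeholder: your trained coarse-model weights
+    "model": {"encoder": "resnet50", "decoder": "pspnet"},
+    "use_bbox": True,            # "Use Bounding Box" checkbox
+    "use_points": True,          # "Use Points" checkbox
+    "use_mask": True,            # "Use Mask" checkbox
+    "do_cca": True,              # "Use CCA" checkbox
+    "point_mode": "extreme",
+    "coarse_pred_only": False,   # "Coarse Prediction Only" checkbox
+    "use_neg_points": False,
+    "base_model": TYPE_ALPNET,
+}
+```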
+
+## Notes
+
+- This demo requires a GPU with CUDA support
+- Large images may require more GPU memory
+- For optimal results, use high-quality support images and masks
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc11ba7160f0fe2dfa97b7c28bbb81683e686e6c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,247 @@
+import os
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import gradio as gr
+from PIL import Image
+import torchvision.transforms as transforms
+from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET
+from models.grid_proto_fewshot import FewShotSeg
+from models.segment_anything.utils.transforms import ResizeLongestSide
+
+# Set environment variables for model caching
+os.environ['TORCH_HOME'] = "./pretrained_model"
+
+# Function to load the model
+def load_model(config):
+ # Initial segmentation model
+ alpnet = FewShotSeg(
+ config["input_size"][0],
+ config["reload_model_path"],
+ config["model"]
+ )
+ alpnet.cuda()
+ base_model = ALPNetWrapper(alpnet)
+
+ # ProtoSAM model
+ sam_checkpoint = "pretrained_model/sam_vit_h.pth"
+ model = ProtoSAM(
+ image_size=(1024, 1024),
+ coarse_segmentation_model=base_model,
+ use_bbox=config["use_bbox"],
+ use_points=config["use_points"],
+ use_mask=config["use_mask"],
+ debug=False,
+ num_points_for_sam=1,
+ use_cca=config["do_cca"],
+ point_mode=config["point_mode"],
+ use_sam_trans=True,
+ coarse_pred_only=config["coarse_pred_only"],
+ sam_pretrained_path=sam_checkpoint,
+ use_neg_points=config["use_neg_points"],
+ )
+ model = model.to(torch.device("cuda"))
+ model.eval()
+ return model
+
+# Function to preprocess images
+def preprocess_image(image, transform):
+ if isinstance(image, np.ndarray):
+ image_np = image
+ else:
+ # Convert PIL Image to numpy array
+ image_np = np.array(image)
+
+ # Convert to RGB if grayscale
+ if len(image_np.shape) == 2:
+ image_np = np.stack([image_np] * 3, axis=2)
+ elif image_np.shape[2] == 1:
+ image_np = np.concatenate([image_np] * 3, axis=2)
+
+ # Apply transforms
+ image_tensor = transform(image_np).unsqueeze(0)
+ return image_tensor
+
+# Function to create overlay visualization
+def create_overlay(query_image, prediction, colormap='YlOrRd'):
+ """
+ Create an overlay of the prediction on the query image
+ """
+ # Convert tensors to numpy arrays for visualization
+ if isinstance(query_image, torch.Tensor):
+ query_image = query_image.cpu().squeeze().numpy()
+
+ if isinstance(prediction, torch.Tensor):
+ prediction = prediction.cpu().squeeze().numpy()
+
+ # Normalize image for visualization
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8)
+
+ # Ensure binary mask
+ prediction = (prediction > 0).astype(np.float32)
+
+ # Create mask overlay
+    mask_cmap = plt.get_cmap(colormap)
+ pred_rgba = mask_cmap(prediction)
+ pred_rgba[..., 3] = prediction * 0.7 # Set alpha channel
+
+ # Create matplotlib figure for overlay
+ fig, ax = plt.subplots(figsize=(10, 10))
+
+ # Handle grayscale vs RGB images
+ if len(query_image.shape) == 2:
+ ax.imshow(query_image, cmap='gray')
+ else:
+ if query_image.shape[0] == 3: # Channel-first format
+ query_image = np.transpose(query_image, (1, 2, 0))
+ ax.imshow(query_image)
+
+ ax.imshow(pred_rgba)
+ ax.axis('off')
+ plt.tight_layout()
+
+ # Convert to PIL Image
+ fig.canvas.draw()
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close(fig)
+
+ return img
+
+# Model configuration
+config = {
+ "input_size": [224],
+ "reload_model_path": "path/to/your/model.pth", # Update with your model path
+ "model": {"encoder": "resnet50", "decoder": "pspnet"},
+ "use_bbox": True,
+ "use_points": True,
+ "use_mask": True,
+ "do_cca": True,
+ "point_mode": "extreme",
+ "coarse_pred_only": False,
+ "use_neg_points": False,
+ "base_model": TYPE_ALPNET
+}
+
+# Function to run inference
+def run_inference(query_image, support_image, support_mask, use_bbox, use_points, use_mask, use_cca, coarse_pred_only):
+ try:
+ # Update config based on user selections
+ config["use_bbox"] = use_bbox
+ config["use_points"] = use_points
+ config["use_mask"] = use_mask
+ config["do_cca"] = use_cca
+ config["coarse_pred_only"] = coarse_pred_only
+
+ # Check if CUDA is available
+ if not torch.cuda.is_available():
+ return None, "CUDA is not available. This demo requires GPU support."
+
+ # Load the model
+ model = load_model(config)
+
+ # Preprocess images
+ sam_trans = ResizeLongestSide(1024)
+
+ # Transform for images
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Resize((1024, 1024), antialias=True),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ # Process query image
+ query_img_tensor = preprocess_image(query_image, transform)
+
+ # Process support image
+ support_img_tensor = preprocess_image(support_image, transform)
+
+ # Process support mask (should be binary)
+ support_mask_np = np.array(support_mask)
+ support_mask_np = (support_mask_np > 127).astype(np.float32) # Binarize mask
+ support_mask_tensor = torch.from_numpy(support_mask_np).unsqueeze(0).unsqueeze(0)
+ support_mask_tensor = torch.nn.functional.interpolate(
+ support_mask_tensor, size=(1024, 1024), mode='nearest'
+ )
+
+ # Prepare model inputs
+ support_images = [support_img_tensor.cuda()]
+ support_masks = [support_mask_tensor.cuda()]
+
+ # Create model input
+ coarse_model_input = InputFactory.create_input(
+ input_type=config["base_model"],
+ query_image=query_img_tensor.cuda(),
+ support_images=support_images,
+ support_labels=support_masks,
+ isval=True,
+ val_wsize=3,
+ original_sz=query_img_tensor.shape[-2:],
+ img_sz=query_img_tensor.shape[-2:],
+ gts=None,
+ )
+ coarse_model_input.to(torch.device("cuda"))
+
+ # Run inference
+ with torch.no_grad():
+ query_pred, scores = model(
+ query_img_tensor.cuda(), coarse_model_input, degrees_rotate=0
+ )
+
+ # Create overlay visualization
+ result_image = create_overlay(query_img_tensor, query_pred)
+
+ confidence_score = np.mean(scores)
+ return result_image, f"Confidence Score: {confidence_score:.4f}"
+
+ except Exception as e:
+ return None, f"Error during inference: {str(e)}"
+
+# Define the Gradio interface
+def create_interface():
+ with gr.Blocks(title="ProtoSAM Segmentation Demo") as demo:
+ gr.Markdown("# ProtoSAM Segmentation Demo")
+ gr.Markdown("Upload a query image, support image, and support mask to generate a segmentation prediction.")
+
+ with gr.Row():
+ with gr.Column():
+ query_image = gr.Image(label="Query Image", type="pil")
+ support_image = gr.Image(label="Support Image", type="pil")
+ support_mask = gr.Image(label="Support Mask", type="pil")
+
+ with gr.Column():
+ result_image = gr.Image(label="Prediction Result")
+ result_text = gr.Textbox(label="Result Information")
+
+ with gr.Row():
+ with gr.Column():
+ use_bbox = gr.Checkbox(label="Use Bounding Box", value=True)
+ use_points = gr.Checkbox(label="Use Points", value=True)
+ use_mask = gr.Checkbox(label="Use Mask", value=True)
+
+ with gr.Column():
+ use_cca = gr.Checkbox(label="Use CCA", value=True)
+ coarse_pred_only = gr.Checkbox(label="Coarse Prediction Only", value=False)
+ run_button = gr.Button("Run Inference")
+
+ run_button.click(
+ fn=run_inference,
+ inputs=[
+ query_image,
+ support_image,
+ support_mask,
+ use_bbox,
+ use_points,
+ use_mask,
+ use_cca,
+ coarse_pred_only
+ ],
+ outputs=[result_image, result_text]
+ )
+
+ return demo
+
+# Create and launch the interface
+if __name__ == "__main__":
+ demo = create_interface()
+ demo.launch(share=True)
\ No newline at end of file
diff --git a/backbone.sh b/backbone.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4319e8f737767d046539f3d0fe94ff704264ad68
--- /dev/null
+++ b/backbone.sh
@@ -0,0 +1,179 @@
+#!/bin/bash
+set -e
+GPUID1=0
+export CUDA_VISIBLE_DEVICES=$GPUID1
+
+MODE=$1
+if [ $MODE != "validation" ] && [ $MODE != "training" ]
+then
+ echo "mode must be either validation or training"
+ exit 1
+fi
+
+# get modality as arg
+MODALITY=$2
+# make sure modality is either ct or mri
+if [ $MODALITY != "ct" ] && [ $MODALITY != "mri" ]
+then
+ echo "modality must be either ct or mri"
+ exit 1
+fi
+
+####### Shared configs ######
+PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training
+INPUT_SIZE=256
+ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4)
+if [ $MODALITY == "ct" ]
+then
+ DATASET='SABS_Superpix'
+else
+ DATASET='CHAOST2_Superpix'
+fi
+
+if [ $INPUT_SIZE -gt 256 ]
+then
+ DATASET=${DATASET}'_672'
+fi
+
+NWORKER=4
+MODEL_NAME='dinov2_l14'
+LORA=0
+RELOAD_PATH=( "None" )
+SKIP_SLICES="True"
+DO_CCA="True"
+TTT="False"
+NSTEP=100000
+RESET_AFTER_SLICE="True"
+FINETUNE_ON_SUPPORT="False"
+USE_SLICE_ADAPTER="False"
+ADAPTER_LAYERS=1
+CLAHE=False
+ALL_SCALE=( "MIDDLE") # config of pseudolabels
+
+LABEL_SETS=$3
+EXCLU='[2,3]'
+
+if [[ $MODALITY == "mri" && $LABEL_SETS -eq 1 ]]
+then
+    echo "excluding 1, 4"
+ EXCLU='[1,4]' # liver(1), spleen(4)
+fi
+
+ORGANS='kidneys'
+if [ $LABEL_SETS -eq 1 ]
+then
+ ORGANS='liver_spleen'
+fi
+
+
+FREE_DESC=""
+CPT="${MODE}_${MODEL_NAME}_${MODALITY}"
+if [ -n "$FREE_DESC" ]
+then
+ CPT="${CPT}_${FREE_DESC}"
+fi
+
+if [[ $TTT == "True" ]]
+then
+ CPT="${CPT}_ttt_nstep_${NSTEP}"
+ if [ $RESET_AFTER_SLICE == "True" ]
+ then
+ CPT="${CPT}_reset_after_slice"
+ fi
+fi
+
+if [ $USE_SLICE_ADAPTER == "True" ]
+then
+ CPT="${CPT}_w_adapter_${ADAPTER_LAYERS}_layers"
+fi
+
+if [ $LORA -ne 0 ]
+then
+ CPT="${CPT}_lora_${LORA}"
+fi
+
+if [ $CLAHE == "True" ]
+then
+ CPT="${CPT}_w_clahe"
+fi
+
+if [ $DO_CCA = "True" ]
+then
+ CPT="${CPT}_cca"
+fi
+
+CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}"
+
+if [ ${EXCLU} = "[]" ]
+then
+ CPT="${CPT}_setting1"
+else
+ CPT="${CPT}_setting2"
+fi
+
+CPT="${CPT}_${ORGANS}_fold"
+
+###### Training configs (irrelevant in testing) ######
+DECAY=0.95
+
+MAX_ITER=1000 # defines the size of an epoch
+SNAPSHOT_INTERVAL=25000 # interval for saving snapshot
+SEED='1234'
+
+###### Validation configs ######
+SUPP_ID='[6]' # using the additionally loaded scan as support
+if [ $MODALITY == "mri" ]
+then
+ SUPP_ID='[4]'
+fi
+
+echo ===================================
+
+for ((i=0; i<${#ALL_EV[@]}; i++))
+do
+ EVAL_FOLD=${ALL_EV[i]}
+ CPT_W_FOLD="${CPT}_${EVAL_FOLD}"
+ echo $CPT_W_FOLD on GPU $GPUID1
+ for SUPERPIX_SCALE in "${ALL_SCALE[@]}"
+ do
+ PREFIX="test_vfold${EVAL_FOLD}"
+ echo $PREFIX
+ LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}"
+
+ if [ ! -d $LOGDIR ]
+ then
+ mkdir -p $LOGDIR
+ fi
+
+python3 $MODE.py with \
+ "modelname=$MODEL_NAME" \
+ 'usealign=True' \
+ 'optim_type=sgd' \
+ reload_model_path=${RELOAD_PATH[i]} \
+ num_workers=$NWORKER \
+ scan_per_load=-1 \
+ label_sets=$LABEL_SETS \
+ 'use_wce=True' \
+ exp_prefix=$PREFIX \
+ 'clsname=grid_proto' \
+ n_steps=$NSTEP \
+ exclude_cls_list=$EXCLU \
+ eval_fold=$EVAL_FOLD \
+ dataset=$DATASET \
+ proto_grid_size=$PROTO_GRID \
+ max_iters_per_load=$MAX_ITER \
+ min_fg_data=1 seed=$SEED \
+ save_snapshot_every=$SNAPSHOT_INTERVAL \
+ superpix_scale=$SUPERPIX_SCALE \
+ lr_step_gamma=$DECAY \
+ path.log_dir=$LOGDIR \
+ support_idx=$SUPP_ID \
+ lora=$LORA \
+ do_cca=$DO_CCA \
+ ttt=$TTT \
+ adapter_layers=$ADAPTER_LAYERS \
+ use_slice_adapter=$USE_SLICE_ADAPTER \
+ reset_after_slice=$RESET_AFTER_SLICE \
+ "input_size=($INPUT_SIZE, $INPUT_SIZE)"
+ done
+done
\ No newline at end of file
diff --git a/config_ssl_upload.py b/config_ssl_upload.py
new file mode 100644
index 0000000000000000000000000000000000000000..fba41dd9acfc415c6256dcd1ef2e5edb6eff4d1a
--- /dev/null
+++ b/config_ssl_upload.py
@@ -0,0 +1,177 @@
+"""
+Experiment configuration file
+Extended from config file from original PANet Repository
+"""
+import os
+import re
+import glob
+import itertools
+
+import sacred
+from sacred import Experiment
+from sacred.observers import FileStorageObserver
+from sacred.utils import apply_backspaces_and_linefeeds
+
+from platform import node
+from datetime import datetime
+
+from util.consts import IMG_SIZE
+
+sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
+sacred.SETTINGS.CAPTURE_MODE = 'no'
+
+ex = Experiment('mySSL')
+ex.captured_out_filter = apply_backspaces_and_linefeeds
+
+source_folders = ['.', './dataloaders', './models', './util']
+sources_to_save = list(itertools.chain.from_iterable(
+ [glob.glob(f'{folder}/*.py') for folder in source_folders]))
+for source_file in sources_to_save:
+ ex.add_source_file(source_file)
+
+@ex.config
+def cfg():
+ """Default configurations"""
+ seed = 1234
+ gpu_id = 0
+ mode = 'train' # for now only allows 'train'
+ do_validation=False
+ num_workers = 4 # 0 for debugging.
+
+ dataset = 'CHAOST2' # i.e. abdominal MRI
+ use_coco_init = True # initialize the backbone with MS-COCO pre-trained weights (note: COCO contains no medical images)
+
+ ### Training
+ n_steps = 100100
+ batch_size = 1
+ lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
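+ # i.e. one milestone every 1000 steps, intended to pair with lr_step_gamma for step-wise LR decay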
+ lr_step_gamma = 0.95
+ ignore_label = 255
+ print_interval = 100
+ save_snapshot_every = 25000
+ max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
+ epochs=1
+ scan_per_load = -1 # number of 3D scans per load (to save memory); if -1, load the entire dataset into memory
+ which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
+ input_size = (IMG_SIZE, IMG_SIZE)
+ min_fg_data='100' # when training with manual annotations, the minimum number of foreground pixels required in a single-class, single-slice sample; this empirically stabilizes training
+ label_sets = 0 # which group of labels is used for training (the rest are held out for testing)
+ curr_cls = "" # choose between rk, lk, spleen and liver
+ exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
+ usealign = True # see vanilla PANet
+ use_wce = True
+ use_dinov2_loss = False
+ dice_loss = False
+ ### Validation
+ z_margin = 0
+ eval_fold = 0 # which fold for 5 fold cross validation
+ support_idx=[-1] # indicating which scan is used as support in testing.
+ val_wsize=2 # L_H, L_W in testing
+ n_sup_part = 3 # number of chunks in testing
+ use_clahe = False
+ use_slice_adapter = False
+ adapter_layers=3
+ debug=True
+ skip_no_organ_slices=True
+ # Network
+ modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
+ clsname = None #
+ reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
+ proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
+ feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size; should be coupled with the backbone in the future
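+ # e.g. input_size=(256, 256) with the stride-8 assumption gives feature_hw=[32, 32], i.e. 4x4 pooling windows for proto_grid_size=8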
+ lora = 0
+ use_3_slices=False
+ do_cca=False
+ use_edge_detector=False
+ finetune_on_support=False
+ sliding_window_confidence_segmentation=False
+ finetune_model_on_single_slice=False
+ online_finetuning=True
+
+ use_bbox=True # for SAM
+ use_points=True # for SAM
+ use_mask=False # for SAM
+ base_model="alpnet" # or "SAM"
+ # SSL
+ superpix_scale = 'MIDDLE' #MIDDLE/ LARGE
+ use_pos_enc=False
+ support_txt_file = None # path to a txt file containing support slices
+ augment_support_set=False
+ coarse_pred_only=False # for ProtoSAM
+ point_mode="both" # for ProtoSAM, choose: both, conf, centroid
+ use_neg_points=False
+ n_support=1 # num support images
+ protosam_sam_ver="sam_h" # or medsam
+ grad_accumulation_steps=1
+ ttt=False
+ reset_after_slice=True # for TTT, whether to reset the model after fine-tuning on each slice
+ model = {
+ 'align': usealign,
+ 'dinov2_loss': use_dinov2_loss,
+ 'use_coco_init': use_coco_init,
+ 'which_model': modelname,
+ 'cls_name': clsname,
+ 'proto_grid_size' : proto_grid_size,
+ 'feature_hw': feature_hw,
+ 'reload_model_path': reload_model_path,
+ 'lora': lora,
+ 'use_slice_adapter': use_slice_adapter,
+ 'adapter_layers': adapter_layers,
+ 'debug': debug,
+ 'use_pos_enc': use_pos_enc
+ }
+
+ task = {
+ 'n_ways': 1,
+ 'n_shots': 1,
+ 'n_queries': 1,
+ 'npart': n_sup_part
+ }
+
+ optim_type = 'sgd'
+ lr=1e-3
+ momentum=0.9
+ weight_decay=0.0005
+ optim = {
+ 'lr': lr,
+ 'momentum': momentum,
+ 'weight_decay': weight_decay
+ }
+
+ exp_prefix = ''
+
+ exp_str = '_'.join(
+ [exp_prefix]
+ + [dataset,]
+ + [f'sets_{label_sets}_{task["n_shots"]}shot'])
+
+ path = {
+ 'log_dir': './runs',
+ 'SABS':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"
+ },
+ 'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"
+ },
+ 'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"
+ },
+ 'C0':{'data_dir': "feed your dataset path here"
+ },
+ 'CHAOST2':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"
+ },
+ 'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"
+ },
+ 'SABS_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"},
+ 'C0_Superpix':{'data_dir': "feed your dataset path here"},
+ 'CHAOST2_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"},
+ 'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"},
+ 'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"},
+ 'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"},
+ }
+
+
+@ex.config_hook
+def add_observer(config, command_name, logger):
+ """A hook fucntion to add observer"""
+ exp_name = f'{ex.path}_{config["exp_str"]}'
+ observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
+ ex.observers.append(observer)
+ return config
diff --git a/data/data_processing.ipynb b/data/data_processing.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..76cdd5bea9eecf70b9d10838a9a19b8bb571de59
--- /dev/null
+++ b/data/data_processing.ipynb
@@ -0,0 +1,2687 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For convinience, I've unified all of the data preprocessing\n",
+ "notebooks from [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation.git) into a single notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%reset\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import glob\n",
+ "import SimpleITK as sitk\n",
+ "import sys\n",
+ "\n",
+ "sys.path.insert(0, '../')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create dirs for the SABS and CHAOS datasets\n",
+ "os.makedirs('./SABS', exist_ok=True)\n",
+ "os.makedirs('./CHAOST2', exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def copy_spacing_ori(src, dst):\n",
+ " dst.SetSpacing(src.GetSpacing())\n",
+ " dst.SetOrigin(src.GetOrigin())\n",
+ " dst.SetDirection(src.GetDirection())\n",
+ " return dst\n",
+ "\n",
+ "# helper functions copy pasted\n",
+ "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n",
+ " resample = sitk.ResampleImageFilter()\n",
+ " resample.SetInterpolator(interpolator)\n",
+ " resample.SetOutputDirection(mov_img_obj.GetDirection())\n",
+ " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n",
+ " mov_spacing = mov_img_obj.GetSpacing()\n",
+ "\n",
+ " resample.SetOutputSpacing(new_spacing)\n",
+ " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n",
+ " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n",
+ "\n",
+ " resample.SetSize( [int(sz+1) for sz in new_size] )\n",
+ " if logging:\n",
+ " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n",
+ " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n",
+ "\n",
+ " return resample.Execute(mov_img_obj)\n",
+ "\n",
+ "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n",
+ " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n",
+ " lbvs = np.unique(src_mat)\n",
+ " if logging:\n",
+ " print(\"Label values: {}\".format(lbvs))\n",
+ " for idx, lbv in enumerate(lbvs):\n",
+ " _src_curr_mat = np.float32(src_mat == lbv) \n",
+ " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n",
+ " _src_curr_obj.CopyInformation(mov_lb_obj)\n",
+ " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n",
+ " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n",
+ " if idx == 0:\n",
+ " out_vol = _tar_curr_mat\n",
+ " else:\n",
+ " out_vol[_tar_curr_mat == lbv] = lbv\n",
+ " out_obj = sitk.GetImageFromArray(out_vol)\n",
+ " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n",
+ " if ref_img != None:\n",
+ " out_obj.CopyInformation(ref_img)\n",
+ " return out_obj\n",
+ " \n",
+ "## Then crop ROI\n",
+ "def get_label_center(label):\n",
+ " nnz = np.sum(label > 1e-5)\n",
+ " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n",
+ "\n",
+ "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n",
+ " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n",
+ " Args:\n",
+ " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n",
+ " only_2d: only do cropping on first two dimensions\n",
+ " \"\"\"\n",
+ " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n",
+ " if only_2d:\n",
+ " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n",
+ " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n",
+ " _expand_cropsize.append(ori_vol.shape[-1])\n",
+ " \n",
+ " image_patch = np.ones(tuple(_expand_cropsize)) * padval\n",
+ "\n",
+ " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n",
+ " _min_idx = [0,0,0]\n",
+ " _max_idx = list(ori_vol.shape)\n",
+ "\n",
+ " # bias of actual cropped size to the beginning and the end of this volume\n",
+ " _bias_start = [0,0,0]\n",
+ " _bias_end = [0,0,0]\n",
+ "\n",
+ " for dim,hsize in enumerate(half_size):\n",
+ " if dim == 2 and only_2d:\n",
+ " break\n",
+ "\n",
+ " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n",
+ " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n",
+ "\n",
+ " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n",
+ " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n",
+ " \n",
+ " if only_2d:\n",
+ " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n",
+ " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n",
+ " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n",
+ " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n",
+ "\n",
+ " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n",
+ " # then goes back to original volume\n",
+ " else:\n",
+ " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n",
+ " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n",
+ " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n",
+ " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n",
+ " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n",
+ " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n",
+ "\n",
+ " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n",
+ " return image_patch\n",
+ "\n",
+ "s2n = sitk.GetArrayFromImage\n",
+ "\n",
+ "\n",
+ "def resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC, required_res=512):\n",
+ " spa_fac = SPA_FAC\n",
+ " for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n",
+ "\n",
+ " # lb_n = nio.read_nii_bysitk(seg_fid)\n",
+ "\n",
+ " img_obj = sitk.ReadImage( img_fid )\n",
+ " seg_obj = sitk.ReadImage( seg_fid )\n",
+ " print(img_fid, seg_fid)\n",
+ " ## image\n",
+ " array = sitk.GetArrayFromImage(img_obj)\n",
+ " H = W = array.shape[-1]\n",
+ " if SPA_FAC is None:\n",
+ " spa_fac = (H - 2 * BD_BIAS) / required_res\n",
+ " print(array.shape, f\"label shape {sitk.GetArrayFromImage(seg_obj).shape}\")\n",
+ " # cropping\n",
+ " array = array[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]\n",
+ " cropped_img_o = sitk.GetImageFromArray(array)\n",
+ " cropped_img_o = copy_spacing_ori(img_obj, cropped_img_o)\n",
+ "\n",
+ " # resampling\n",
+ " img_spa_ori = img_obj.GetSpacing()\n",
+ " res_img_o = resample_by_res(cropped_img_o, [img_spa_ori[0] * spa_fac, img_spa_ori[1] * spa_fac, img_spa_ori[-1]], interpolator = sitk.sitkLinear,\n",
+ " logging = True)\n",
+ "\n",
+ " ## label\n",
+ " lb_arr = sitk.GetArrayFromImage(seg_obj)\n",
+ " # cropping\n",
+ " lb_arr = lb_arr[:,BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]\n",
+ " cropped_lb_o = sitk.GetImageFromArray(lb_arr)\n",
+ " cropped_lb_o = copy_spacing_ori(seg_obj, cropped_lb_o)\n",
+ "\n",
+ " lb_spa_ori = seg_obj.GetSpacing()\n",
+ "\n",
+ " # resampling\n",
+ " res_lb_o = resample_lb_by_res(cropped_lb_o, [lb_spa_ori[0] * spa_fac, lb_spa_ori[1] * spa_fac, lb_spa_ori[-1] ], interpolator = sitk.sitkLinear,\n",
+ " ref_img = res_img_o, logging = True)\n",
+ "\n",
+ " \n",
+ " out_img_fid = os.path.join( scan_dir, f'image_{pid}.nii.gz' )\n",
+ " out_lb_fid = os.path.join( scan_dir, f'label_{pid}.nii.gz' ) \n",
+ " \n",
+ " # then save\n",
+ " sitk.WriteImage(res_img_o, out_img_fid, True) \n",
+ " sitk.WriteImage(res_lb_o, out_lb_fid, True) \n",
+ " print(f\"{out_img_fid} has been saved, shape: {res_img_o.GetSize()}\")\n",
+ " print(f\"{out_lb_fid} has been saved\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Intensitiy Normalization for CT Images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set up directories for images\n",
+ "IMG_FOLDER=\"./miccai2015/RawData/Training/img\"\n",
+ "SEG_FOLDER=\"./miccai2015/RawData/Training/label\"\n",
+ "OUT_FOLDER=\"./SABS/tmp_normalized/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[]\n"
+ ]
+ }
+ ],
+ "source": [
+ "imgs = sorted(glob.glob(IMG_FOLDER + \"/*.nii.gz\"))\n",
+ "segs = sorted(glob.glob(SEG_FOLDER + \"/*.nii.gz\"))\n",
+ "pids = [pid.split(\"img\")[-1].split(\".\")[0] for pid in imgs]\n",
+ "print(sorted(pids))\n",
+ "assert len(imgs) == len(segs)\n",
+ "for img, seg in zip(imgs, segs):\n",
+ " print(img, seg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(147, 512, 512) label shape (147, 512, 512)\n",
+ "./SABS/tmp_normalized/image_0.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_0.nii.gz has been save\n",
+ "(139, 512, 512) label shape (139, 512, 512)\n",
+ "./SABS/tmp_normalized/image_1.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_1.nii.gz has been save\n",
+ "(198, 512, 512) label shape (198, 512, 512)\n",
+ "./SABS/tmp_normalized/image_2.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_2.nii.gz has been save\n",
+ "(140, 512, 512) label shape (140, 512, 512)\n",
+ "./SABS/tmp_normalized/image_3.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_3.nii.gz has been save\n",
+ "(117, 512, 512) label shape (117, 512, 512)\n",
+ "./SABS/tmp_normalized/image_4.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_4.nii.gz has been save\n",
+ "(131, 512, 512) label shape (131, 512, 512)\n",
+ "./SABS/tmp_normalized/image_5.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_5.nii.gz has been save\n",
+ "(163, 512, 512) label shape (163, 512, 512)\n",
+ "./SABS/tmp_normalized/image_6.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_6.nii.gz has been save\n",
+ "(148, 512, 512) label shape (148, 512, 512)\n",
+ "./SABS/tmp_normalized/image_7.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_7.nii.gz has been save\n",
+ "(149, 512, 512) label shape (149, 512, 512)\n",
+ "./SABS/tmp_normalized/image_8.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_8.nii.gz has been save\n",
+ "(148, 512, 512) label shape (148, 512, 512)\n",
+ "./SABS/tmp_normalized/image_9.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_9.nii.gz has been save\n",
+ "(143, 512, 512) label shape (143, 512, 512)\n",
+ "./SABS/tmp_normalized/image_10.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_10.nii.gz has been save\n",
+ "(89, 512, 512) label shape (89, 512, 512)\n",
+ "./SABS/tmp_normalized/image_11.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_11.nii.gz has been save\n",
+ "(96, 512, 512) label shape (96, 512, 512)\n",
+ "./SABS/tmp_normalized/image_12.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_12.nii.gz has been save\n",
+ "(124, 512, 512) label shape (124, 512, 512)\n",
+ "./SABS/tmp_normalized/image_13.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_13.nii.gz has been save\n",
+ "(85, 512, 512) label shape (85, 512, 512)\n",
+ "./SABS/tmp_normalized/image_14.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_14.nii.gz has been save\n",
+ "(131, 512, 512) label shape (131, 512, 512)\n",
+ "./SABS/tmp_normalized/image_15.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_15.nii.gz has been save\n",
+ "(88, 512, 512) label shape (88, 512, 512)\n",
+ "./SABS/tmp_normalized/image_16.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_16.nii.gz has been save\n",
+ "(89, 512, 512) label shape (89, 512, 512)\n",
+ "./SABS/tmp_normalized/image_17.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_17.nii.gz has been save\n",
+ "(100, 512, 512) label shape (100, 512, 512)\n",
+ "./SABS/tmp_normalized/image_18.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_18.nii.gz has been save\n",
+ "(153, 512, 512) label shape (153, 512, 512)\n",
+ "./SABS/tmp_normalized/image_19.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_19.nii.gz has been save\n",
+ "(93, 512, 512) label shape (93, 512, 512)\n",
+ "./SABS/tmp_normalized/image_20.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_20.nii.gz has been save\n",
+ "(144, 512, 512) label shape (144, 512, 512)\n",
+ "./SABS/tmp_normalized/image_21.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_21.nii.gz has been save\n",
+ "(104, 512, 512) label shape (104, 512, 512)\n",
+ "./SABS/tmp_normalized/image_22.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_22.nii.gz has been save\n",
+ "(98, 512, 512) label shape (98, 512, 512)\n",
+ "./SABS/tmp_normalized/image_23.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_23.nii.gz has been save\n",
+ "(94, 512, 512) label shape (94, 512, 512)\n",
+ "./SABS/tmp_normalized/image_24.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_24.nii.gz has been save\n",
+ "(184, 512, 512) label shape (184, 512, 512)\n",
+ "./SABS/tmp_normalized/image_25.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_25.nii.gz has been save\n",
+ "(99, 512, 512) label shape (99, 512, 512)\n",
+ "./SABS/tmp_normalized/image_26.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_26.nii.gz has been save\n",
+ "(100, 512, 512) label shape (100, 512, 512)\n",
+ "./SABS/tmp_normalized/image_27.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_27.nii.gz has been save\n",
+ "(90, 512, 512) label shape (90, 512, 512)\n",
+ "./SABS/tmp_normalized/image_28.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_28.nii.gz has been save\n",
+ "(195, 512, 512) label shape (195, 512, 512)\n",
+ "./SABS/tmp_normalized/image_29.nii.gz has been save\n",
+ "./SABS/tmp_normalized/label_29.nii.gz has been save\n"
+ ]
+ }
+ ],
+ "source": [
+ "import copy\n",
+ "scan_dir = OUT_FOLDER\n",
+ "LIR = -125\n",
+ "HIR = 275\n",
+ "os.makedirs(scan_dir, exist_ok = True)\n",
+ "\n",
+ "reindex = 0\n",
+ "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n",
+ "\n",
+ " img_obj = sitk.ReadImage( img_fid )\n",
+ " seg_obj = sitk.ReadImage( seg_fid )\n",
+ "\n",
+ " array = sitk.GetArrayFromImage(img_obj)\n",
+ " print(array.shape, f\"label shape {sitk.GetArrayFromImage(seg_obj).shape}\")\n",
+ " array[array > HIR] = HIR\n",
+ " array[array < LIR] = LIR\n",
+ " \n",
+ " array = (array - array.min()) / (array.max() - array.min()) * 255.0\n",
+ " \n",
+ " # then normalize this\n",
+ " \n",
+ " wined_img = sitk.GetImageFromArray(array)\n",
+ " wined_img = copy_spacing_ori(img_obj, wined_img)\n",
+ " \n",
+ " out_img_fid = os.path.join( scan_dir, f'image_{str(reindex)}.nii.gz' )\n",
+ " out_lb_fid = os.path.join( scan_dir, f'label_{str(reindex)}.nii.gz' ) \n",
+ " \n",
+ " # then save\n",
+ " sitk.WriteImage(wined_img, out_img_fid, True) \n",
+ " sitk.WriteImage(seg_obj, out_lb_fid, True) \n",
+ " print(\"{} has been save\".format(out_img_fid))\n",
+ " print(\"{} has been save\".format(out_lb_fid))\n",
+ " reindex += 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Overview\n",
+ "\n",
+ "This is the second step of preprocessing\n",
+ "\n",
+ "Cut out irrelevant empty boundary and resample to 512x512 in axial plane.\n",
+ "\n",
+ "Input: intensity-normalized images\n",
+ "\n",
+ "Output: spacially resampled images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '4', '5', '6', '7', '8', '9']\n",
+ "./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz\n",
+ "./SABS/tmp_normalized/image_1.nii.gz ./SABS/tmp_normalized/label_1.nii.gz\n",
+ "./SABS/tmp_normalized/image_10.nii.gz ./SABS/tmp_normalized/label_10.nii.gz\n",
+ "./SABS/tmp_normalized/image_11.nii.gz ./SABS/tmp_normalized/label_11.nii.gz\n",
+ "./SABS/tmp_normalized/image_12.nii.gz ./SABS/tmp_normalized/label_12.nii.gz\n",
+ "./SABS/tmp_normalized/image_13.nii.gz ./SABS/tmp_normalized/label_13.nii.gz\n",
+ "./SABS/tmp_normalized/image_14.nii.gz ./SABS/tmp_normalized/label_14.nii.gz\n",
+ "./SABS/tmp_normalized/image_15.nii.gz ./SABS/tmp_normalized/label_15.nii.gz\n",
+ "./SABS/tmp_normalized/image_16.nii.gz ./SABS/tmp_normalized/label_16.nii.gz\n",
+ "./SABS/tmp_normalized/image_17.nii.gz ./SABS/tmp_normalized/label_17.nii.gz\n",
+ "./SABS/tmp_normalized/image_18.nii.gz ./SABS/tmp_normalized/label_18.nii.gz\n",
+ "./SABS/tmp_normalized/image_19.nii.gz ./SABS/tmp_normalized/label_19.nii.gz\n",
+ "./SABS/tmp_normalized/image_2.nii.gz ./SABS/tmp_normalized/label_2.nii.gz\n",
+ "./SABS/tmp_normalized/image_20.nii.gz ./SABS/tmp_normalized/label_20.nii.gz\n",
+ "./SABS/tmp_normalized/image_21.nii.gz ./SABS/tmp_normalized/label_21.nii.gz\n",
+ "./SABS/tmp_normalized/image_22.nii.gz ./SABS/tmp_normalized/label_22.nii.gz\n",
+ "./SABS/tmp_normalized/image_23.nii.gz ./SABS/tmp_normalized/label_23.nii.gz\n",
+ "./SABS/tmp_normalized/image_24.nii.gz ./SABS/tmp_normalized/label_24.nii.gz\n",
+ "./SABS/tmp_normalized/image_25.nii.gz ./SABS/tmp_normalized/label_25.nii.gz\n",
+ "./SABS/tmp_normalized/image_26.nii.gz ./SABS/tmp_normalized/label_26.nii.gz\n",
+ "./SABS/tmp_normalized/image_27.nii.gz ./SABS/tmp_normalized/label_27.nii.gz\n",
+ "./SABS/tmp_normalized/image_28.nii.gz ./SABS/tmp_normalized/label_28.nii.gz\n",
+ "./SABS/tmp_normalized/image_29.nii.gz ./SABS/tmp_normalized/label_29.nii.gz\n",
+ "./SABS/tmp_normalized/image_3.nii.gz ./SABS/tmp_normalized/label_3.nii.gz\n",
+ "./SABS/tmp_normalized/image_4.nii.gz ./SABS/tmp_normalized/label_4.nii.gz\n",
+ "./SABS/tmp_normalized/image_5.nii.gz ./SABS/tmp_normalized/label_5.nii.gz\n",
+ "./SABS/tmp_normalized/image_6.nii.gz ./SABS/tmp_normalized/label_6.nii.gz\n",
+ "./SABS/tmp_normalized/image_7.nii.gz ./SABS/tmp_normalized/label_7.nii.gz\n",
+ "./SABS/tmp_normalized/image_8.nii.gz ./SABS/tmp_normalized/label_8.nii.gz\n",
+ "./SABS/tmp_normalized/image_9.nii.gz ./SABS/tmp_normalized/label_9.nii.gz\n"
+ ]
+ }
+ ],
+ "source": [
+ "IMG_FOLDER = \"./SABS/tmp_normalized\"\n",
+ "\n",
+ "SEG_FOLDER = IMG_FOLDER\n",
+ "imgs = glob.glob(IMG_FOLDER + \"/image_*.nii.gz\")\n",
+ "imgs = sorted([ fid for fid in sorted(imgs) ])\n",
+ "segs = sorted([ fid for fid in glob.glob(SEG_FOLDER + \"/label_*.nii.gz\")])\n",
+ "\n",
+ "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]\n",
+ "print(pids)\n",
+ "for img, seg in zip(imgs, segs):\n",
+ " print(img, seg)"
+ ]
+ },
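+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of this resampling step, reusing the `resample_imgs` helper defined above. The `BD_BIAS` and `SPA_FAC` values are illustrative assumptions (a 32-pixel boundary crop with unchanged in-plane spacing), not necessarily the settings of the run recorded below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch only -- BD_BIAS and SPA_FAC are assumptions, see the note above.\n",
+ "OUT_RESAMPLED = \"./SABS/sabs_CT_normalized/\"\n",
+ "os.makedirs(OUT_RESAMPLED, exist_ok=True)\n",
+ "\n",
+ "BD_BIAS = 32   # crop 32 boundary pixels per side: 512 - 2*32 = 448\n",
+ "SPA_FAC = 1.0  # keep the original in-plane spacing\n",
+ "\n",
+ "# imgs, segs, pids come from the file listing in the previous cell\n",
+ "resample_imgs(imgs, segs, pids, OUT_RESAMPLED, BD_BIAS, SPA_FAC)"
+ ]
+ },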
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz\n",
+ "(147, 512, 512) label shape (147, 512, 512)\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n",
+ "Size (448, 448, 147) -> [448. 448. 147.]\n",
+ "./SABS/sabs_CT_normalized/image_0.nii.gz has been saved, shape: (449, 449, 148)\n",
+ "./SABS/sabs_CT_normalized/label_0.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_1.nii.gz ./SABS/tmp_normalized/label_1.nii.gz\n",
+ "(139, 512, 512) label shape (139, 512, 512)\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n",
+ "Size (448, 448, 139) -> [448. 448. 139.]\n",
+ "./SABS/sabs_CT_normalized/image_1.nii.gz has been saved, shape: (449, 449, 140)\n",
+ "./SABS/sabs_CT_normalized/label_1.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_10.nii.gz ./SABS/tmp_normalized/label_10.nii.gz\n",
+ "(143, 512, 512) label shape (143, 512, 512)\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n",
+ "Size (448, 448, 143) -> [448. 448. 143.]\n",
+ "./SABS/sabs_CT_normalized/image_10.nii.gz has been saved, shape: (449, 449, 144)\n",
+ "./SABS/sabs_CT_normalized/label_10.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_11.nii.gz ./SABS/tmp_normalized/label_11.nii.gz\n",
+ "(89, 512, 512) label shape (89, 512, 512)\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "./SABS/sabs_CT_normalized/image_11.nii.gz has been saved, shape: (449, 449, 90)\n",
+ "./SABS/sabs_CT_normalized/label_11.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_12.nii.gz ./SABS/tmp_normalized/label_12.nii.gz\n",
+ "(96, 512, 512) label shape (96, 512, 512)\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Label values: [ 0 1 2 3 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n",
+ "Size (448, 448, 96) -> [448. 448. 96.]\n",
+ "./SABS/sabs_CT_normalized/image_12.nii.gz has been saved, shape: (449, 449, 97)\n",
+ "./SABS/sabs_CT_normalized/label_12.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_13.nii.gz ./SABS/tmp_normalized/label_13.nii.gz\n",
+ "(124, 512, 512) label shape (124, 512, 512)\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n",
+ "Size (448, 448, 124) -> [448. 448. 124.]\n",
+ "./SABS/sabs_CT_normalized/image_13.nii.gz has been saved, shape: (449, 449, 125)\n",
+ "./SABS/sabs_CT_normalized/label_13.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_14.nii.gz ./SABS/tmp_normalized/label_14.nii.gz\n",
+ "(85, 512, 512) label shape (85, 512, 512)\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n",
+ "Size (448, 448, 85) -> [448. 448. 85.]\n",
+ "./SABS/sabs_CT_normalized/image_14.nii.gz has been saved, shape: (449, 449, 86)\n",
+ "./SABS/sabs_CT_normalized/label_14.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_15.nii.gz ./SABS/tmp_normalized/label_15.nii.gz\n",
+ "(131, 512, 512) label shape (131, 512, 512)\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "./SABS/sabs_CT_normalized/image_15.nii.gz has been saved, shape: (449, 449, 132)\n",
+ "./SABS/sabs_CT_normalized/label_15.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_16.nii.gz ./SABS/tmp_normalized/label_16.nii.gz\n",
+ "(88, 512, 512) label shape (88, 512, 512)\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n",
+ "Size (448, 448, 88) -> [448. 448. 88.]\n",
+ "./SABS/sabs_CT_normalized/image_16.nii.gz has been saved, shape: (449, 449, 89)\n",
+ "./SABS/sabs_CT_normalized/label_16.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_17.nii.gz ./SABS/tmp_normalized/label_17.nii.gz\n",
+ "(89, 512, 512) label shape (89, 512, 512)\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n",
+ "Size (448, 448, 89) -> [448. 448. 89.]\n",
+ "./SABS/sabs_CT_normalized/image_17.nii.gz has been saved, shape: (449, 449, 90)\n",
+ "./SABS/sabs_CT_normalized/label_17.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_18.nii.gz ./SABS/tmp_normalized/label_18.nii.gz\n",
+ "(100, 512, 512) label shape (100, 512, 512)\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "./SABS/sabs_CT_normalized/image_18.nii.gz has been saved, shape: (449, 449, 101)\n",
+ "./SABS/sabs_CT_normalized/label_18.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_19.nii.gz ./SABS/tmp_normalized/label_19.nii.gz\n",
+ "(153, 512, 512) label shape (153, 512, 512)\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 153) -> [448. 448. 153.]\n",
+ "./SABS/sabs_CT_normalized/image_19.nii.gz has been saved, shape: (449, 449, 154)\n",
+ "./SABS/sabs_CT_normalized/label_19.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_2.nii.gz ./SABS/tmp_normalized/label_2.nii.gz\n",
+ "(198, 512, 512) label shape (198, 512, 512)\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n",
+ "Size (448, 448, 198) -> [448. 448. 198.]\n",
+ "./SABS/sabs_CT_normalized/image_2.nii.gz has been saved, shape: (449, 449, 199)\n",
+ "./SABS/sabs_CT_normalized/label_2.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_20.nii.gz ./SABS/tmp_normalized/label_20.nii.gz\n",
+ "(93, 512, 512) label shape (93, 512, 512)\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n",
+ "Size (448, 448, 93) -> [448. 448. 93.]\n",
+ "./SABS/sabs_CT_normalized/image_20.nii.gz has been saved, shape: (449, 449, 94)\n",
+ "./SABS/sabs_CT_normalized/label_20.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_21.nii.gz ./SABS/tmp_normalized/label_21.nii.gz\n",
+ "(144, 512, 512) label shape (144, 512, 512)\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n",
+ "Size (448, 448, 144) -> [448. 448. 144.]\n",
+ "./SABS/sabs_CT_normalized/image_21.nii.gz has been saved, shape: (449, 449, 145)\n",
+ "./SABS/sabs_CT_normalized/label_21.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_22.nii.gz ./SABS/tmp_normalized/label_22.nii.gz\n",
+ "(104, 512, 512) label shape (104, 512, 512)\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n",
+ "Size (448, 448, 104) -> [448. 448. 104.]\n",
+ "./SABS/sabs_CT_normalized/image_22.nii.gz has been saved, shape: (449, 449, 105)\n",
+ "./SABS/sabs_CT_normalized/label_22.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_23.nii.gz ./SABS/tmp_normalized/label_23.nii.gz\n",
+ "(98, 512, 512) label shape (98, 512, 512)\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n",
+ "Size (448, 448, 98) -> [448. 448. 98.]\n",
+ "./SABS/sabs_CT_normalized/image_23.nii.gz has been saved, shape: (449, 449, 99)\n",
+ "./SABS/sabs_CT_normalized/label_23.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_24.nii.gz ./SABS/tmp_normalized/label_24.nii.gz\n",
+ "(94, 512, 512) label shape (94, 512, 512)\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Label values: [ 0 1 2 3 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n",
+ "Size (448, 448, 94) -> [448. 448. 94.]\n",
+ "./SABS/sabs_CT_normalized/image_24.nii.gz has been saved, shape: (449, 449, 95)\n",
+ "./SABS/sabs_CT_normalized/label_24.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_25.nii.gz ./SABS/tmp_normalized/label_25.nii.gz\n",
+ "(184, 512, 512) label shape (184, 512, 512)\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n",
+ "Size (448, 448, 184) -> [448. 448. 184.]\n",
+ "./SABS/sabs_CT_normalized/image_25.nii.gz has been saved, shape: (449, 449, 185)\n",
+ "./SABS/sabs_CT_normalized/label_25.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_26.nii.gz ./SABS/tmp_normalized/label_26.nii.gz\n",
+ "(99, 512, 512) label shape (99, 512, 512)\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n",
+ "Size (448, 448, 99) -> [448. 448. 99.]\n",
+ "./SABS/sabs_CT_normalized/image_26.nii.gz has been saved, shape: (449, 449, 100)\n",
+ "./SABS/sabs_CT_normalized/label_26.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_27.nii.gz ./SABS/tmp_normalized/label_27.nii.gz\n",
+ "(100, 512, 512) label shape (100, 512, 512)\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n",
+ "Size (448, 448, 100) -> [448. 448. 100.]\n",
+ "./SABS/sabs_CT_normalized/image_27.nii.gz has been saved, shape: (449, 449, 101)\n",
+ "./SABS/sabs_CT_normalized/label_27.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_28.nii.gz ./SABS/tmp_normalized/label_28.nii.gz\n",
+ "(90, 512, 512) label shape (90, 512, 512)\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n",
+ "Size (448, 448, 90) -> [448. 448. 90.]\n",
+ "./SABS/sabs_CT_normalized/image_28.nii.gz has been saved, shape: (449, 449, 91)\n",
+ "./SABS/sabs_CT_normalized/label_28.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_29.nii.gz ./SABS/tmp_normalized/label_29.nii.gz\n",
+ "(195, 512, 512) label shape (195, 512, 512)\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n",
+ "Size (448, 448, 195) -> [448. 448. 195.]\n",
+ "./SABS/sabs_CT_normalized/image_29.nii.gz has been saved, shape: (449, 449, 196)\n",
+ "./SABS/sabs_CT_normalized/label_29.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_3.nii.gz ./SABS/tmp_normalized/label_3.nii.gz\n",
+ "(140, 512, 512) label shape (140, 512, 512)\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n",
+ "Size (448, 448, 140) -> [448. 448. 140.]\n",
+ "./SABS/sabs_CT_normalized/image_3.nii.gz has been saved, shape: (449, 449, 141)\n",
+ "./SABS/sabs_CT_normalized/label_3.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_4.nii.gz ./SABS/tmp_normalized/label_4.nii.gz\n",
+ "(117, 512, 512) label shape (117, 512, 512)\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n",
+ "Size (448, 448, 117) -> [448. 448. 117.]\n",
+ "./SABS/sabs_CT_normalized/image_4.nii.gz has been saved, shape: (449, 449, 118)\n",
+ "./SABS/sabs_CT_normalized/label_4.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_5.nii.gz ./SABS/tmp_normalized/label_5.nii.gz\n",
+ "(131, 512, 512) label shape (131, 512, 512)\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n",
+ "Size (448, 448, 131) -> [448. 448. 131.]\n",
+ "./SABS/sabs_CT_normalized/image_5.nii.gz has been saved, shape: (449, 449, 132)\n",
+ "./SABS/sabs_CT_normalized/label_5.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_6.nii.gz ./SABS/tmp_normalized/label_6.nii.gz\n",
+ "(163, 512, 512) label shape (163, 512, 512)\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n",
+ "Size (448, 448, 163) -> [448. 448. 163.]\n",
+ "./SABS/sabs_CT_normalized/image_6.nii.gz has been saved, shape: (449, 449, 164)\n",
+ "./SABS/sabs_CT_normalized/label_6.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_7.nii.gz ./SABS/tmp_normalized/label_7.nii.gz\n",
+ "(148, 512, 512) label shape (148, 512, 512)\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "./SABS/sabs_CT_normalized/image_7.nii.gz has been saved, shape: (449, 449, 149)\n",
+ "./SABS/sabs_CT_normalized/label_7.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_8.nii.gz ./SABS/tmp_normalized/label_8.nii.gz\n",
+ "(149, 512, 512) label shape (149, 512, 512)\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n",
+ "Size (448, 448, 149) -> [448. 448. 149.]\n",
+ "./SABS/sabs_CT_normalized/image_8.nii.gz has been saved, shape: (449, 449, 150)\n",
+ "./SABS/sabs_CT_normalized/label_8.nii.gz has been saved\n",
+ "./SABS/tmp_normalized/image_9.nii.gz ./SABS/tmp_normalized/label_9.nii.gz\n",
+ "(148, 512, 512) label shape (148, 512, 512)\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n",
+ "Size (448, 448, 148) -> [448. 448. 148.]\n",
+ "./SABS/sabs_CT_normalized/image_9.nii.gz has been saved, shape: (449, 449, 149)\n",
+ "./SABS/sabs_CT_normalized/label_9.nii.gz has been saved\n"
+ ]
+ }
+ ],
+ "source": [
+ "import copy\n",
+ "OUT_FOLDER = \"./SABS/sabs_CT_normalized\"\n",
+ "BD_BIAS = 32  # crop the irrelevant empty boundary so that the ROI stands out\n",
+ "\n",
+ "# SPA_FAC = (512 - 2 * BD_BIAS) / 512 # spacing factor\n",
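+ "# Note (added): resample each scan at the two target resolutions below (256 and 672);\n",
+ "# the 672 variant is written to a separate output folder with the \"_672\" suffix.\n",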
+ "for res in (256, 672):\n",
+ " if res == 672:\n",
+ " OUT_FOLDER += \"_672\"\n",
+ " scan_dir = OUT_FOLDER\n",
+ " os.makedirs(OUT_FOLDER, exist_ok = True)\n",
+ "\n",
+ " resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC=None, required_res=res)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Synapse Classmap Generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pid 0 finished!\n",
+ "pid 1 finished!\n",
+ "pid 2 finished!\n",
+ "pid 3 finished!\n",
+ "pid 4 finished!\n",
+ "pid 5 finished!\n",
+ "pid 6 finished!\n",
+ "pid 7 finished!\n",
+ "pid 8 finished!\n",
+ "pid 9 finished!\n",
+ "pid 10 finished!\n",
+ "pid 11 finished!\n",
+ "pid 12 finished!\n",
+ "pid 13 finished!\n",
+ "pid 14 finished!\n",
+ "pid 15 finished!\n",
+ "pid 16 finished!\n",
+ "pid 17 finished!\n",
+ "pid 18 finished!\n",
+ "pid 19 finished!\n",
+ "pid 20 finished!\n",
+ "pid 21 finished!\n",
+ "pid 22 finished!\n",
+ "pid 23 finished!\n",
+ "pid 24 finished!\n",
+ "pid 25 finished!\n",
+ "pid 26 finished!\n",
+ "pid 27 finished!\n",
+ "pid 28 finished!\n",
+ "pid 29 finished!\n",
+ "pid 30 finished!\n",
+ "pid 31 finished!\n",
+ "pid 32 finished!\n",
+ "pid 33 finished!\n",
+ "pid 34 finished!\n",
+ "pid 35 finished!\n",
+ "pid 36 finished!\n",
+ "pid 37 finished!\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "# import niftiio as nio\n",
+ "import SimpleITK as sitk\n",
+ "\n",
+ "# Build a classmap for each normalized dataset: for every organ label, record which\n",
+ "# slices of each scan contain it (with at least MIN_TP foreground pixels).\n",
+ "IMG_BNAMES = (\"./SABS/sabs_CT_normalized/image_*.nii.gz\", \"./SABS/sabs_CT_normalized_672/image_*.nii.gz\")\n",
+ "SEG_NAMES = (\"./SABS/sabs_CT_normalized/label_*.nii.gz\", \"./SABS/sabs_CT_normalized_672/label_*.nii.gz\")\n",
+ "for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):\n",
+ " imgs = glob.glob(IMG_BNAME)\n",
+ " segs = glob.glob(SEG_BNAME)\n",
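+ "    # sort scans by the numeric patient id embedded in the filename (e.g. image_12.nii.gz -> 12)\n",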
+ " imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n",
+ " segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n",
+ " for img, seg in zip(imgs, segs):\n",
+ " print(img, seg)\n",
+ "\n",
+ " classmap = {}\n",
+ " LABEL_NAME = [\"BGD\", \"SPLEEN\", \"KID_R\", \"KID_l\", \"GALLBLADDER\", \"ESOPHAGUS\", \"LIVER\", \"STOMACH\", \"AORTA\", \"IVC\", \"PS_VEIN\", \"PANCREAS\", \"AG_R\", \"AG_L\"] \n",
+ "\n",
+ "    MIN_TP = 1  # minimum number of true positive pixels in a slice\n",
+ "\n",
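+ "    # the classmap is written as classmap_{MIN_TP}.json next to the normalized images\n",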
+ " fid = os.path.dirname(IMG_BNAME) + f'/classmap_{MIN_TP}.json'\n",
+ " for _lb in LABEL_NAME:\n",
+ " classmap[_lb] = {}\n",
+ " for pid in range(len(segs)):\n",
+ " classmap[_lb][str(pid)] = []\n",
+ "\n",
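+ "    # for every scan, record the indices of all slices that contain at least MIN_TP pixels of each class\n",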
+ " for pid, seg in enumerate(segs):\n",
+ " # lb_vol = nio.read_nii_bysitk(seg)\n",
+ " lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))\n",
+ " n_slice = lb_vol.shape[0]\n",
+ " for slc in range(n_slice):\n",
+ " for cls in range(len(LABEL_NAME)):\n",
+ " if cls in lb_vol[slc, ...]:\n",
+ " if np.sum( lb_vol[slc, ...] == cls) >= MIN_TP:\n",
+ " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n",
+ "        print(f'pid {pid} finished!')\n",
+ " \n",
+ " with open(fid, 'w') as fopen:\n",
+ " json.dump(classmap, fopen)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MRI Image Normalization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## PLEASE RUN dcm_img_to_nii.sh first to convert the DICOM series to nii.gz\n",
+ "! ./dcm_img_to_nii.sh"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [316. 316. 35.99999911]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_1.nii.gz has been saved\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 36) -> [348. 348. 35.99999911]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_10.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_13.nii.gz has been saved\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_15.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_19.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_2.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 26) -> [348. 348. 30.38961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_20.nii.gz has been saved\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 31) -> [340. 340. 32.20779221]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_21.nii.gz has been saved\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 33) -> [356. 356. 37.28571347]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_22.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 30) -> [348. 348. 35.06493506]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_3.nii.gz has been saved\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [332. 332. 33.8961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_31.nii.gz has been saved\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 39) -> [372. 372. 45.58441558]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_32.nii.gz has been saved\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 29) -> [356. 356. 33.8961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_33.nii.gz has been saved\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [356. 356. 34.28571503]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_34.nii.gz has been saved\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 38) -> [332. 332. 44.41558442]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_36.nii.gz has been saved\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 32) -> [300. 300. 37.4025974]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_37.nii.gz has been saved\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (320, 320, 34) -> [348. 348. 39.74025974]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_38.nii.gz has been saved\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 26) -> [324. 324. 30.38961039]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_39.nii.gz has been saved\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (256, 256, 30) -> [340. 340. 35.06493506]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_5.nii.gz has been saved\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n",
+ "Size (288, 288, 32) -> [324. 324. 33.24675325]\n",
+ "./CHAOST2/chaos_MR_T2_normalized/image_8.nii.gz has been saved\n"
+ ]
+ }
+ ],
+ "source": [
+ "import copy\n",
+ "\n",
+ "IMG_FOLDER = \"./CHAOST2/niis/T2SPIR\" #, path of nii-like images from step 1\n",
+ "OUT_FOLDER=\"./CHAOST2/chaos_MR_T2_normalized/\" # output directory\n",
+ "\n",
+ "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n",
+ "imgs = [ fid for fid in sorted(imgs) ]\n",
+ "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n",
+ "\n",
+ "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]\n",
+ "for img, seg in zip(imgs, segs):\n",
+ " print(img, seg)\n",
+ "\n",
+ "os.makedirs(OUT_FOLDER, exist_ok = True)\n",
+ " \n",
+ "HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect\n",
+ "\n",
+ "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n",
+ "\n",
+ "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n",
+ "\n",
+ " resample_flg = True\n",
+ "\n",
+ " img_obj = sitk.ReadImage( img_fid )\n",
+ " seg_obj = sitk.ReadImage( seg_fid )\n",
+ "\n",
+ " array = sitk.GetArrayFromImage(img_obj)\n",
+ "\n",
+ " # cut histogram\n",
+ " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n",
+ " array[array > hir] = hir\n",
+ "\n",
+ " his_img_o = sitk.GetImageFromArray(array)\n",
+ " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n",
+ "\n",
+ " # resampling\n",
+ " img_spa_ori = img_obj.GetSpacing()\n",
+ " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n",
+ " interpolator = sitk.sitkLinear, logging = True)\n",
+ " ## label\n",
+ " lb_arr = sitk.GetArrayFromImage(seg_obj)\n",
+ "\n",
+ " # resampling\n",
+ " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n",
+ " ref_img = None, logging = True)\n",
+ "\n",
+ " # crop out rois\n",
+ " res_img_a = s2n(res_img_o)\n",
+ "\n",
+ " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n",
+ " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n",
+ " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n",
+ "\n",
+ " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n",
+ "\n",
+ " res_lb_a = s2n(res_lb_o)\n",
+ "\n",
+ " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n",
+ " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n",
+ " padval = 0, only_2d = True).transpose(2,0,1)\n",
+ "\n",
+ " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n",
+ "\n",
+ "\n",
+ " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n",
+ " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n",
+ "\n",
+ " # then save pre-processed images\n",
+ " sitk.WriteImage(out_img_obj, out_img_fid, True) \n",
+ " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n",
+ " print(\"{} has been saved\".format(out_img_fid))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MRI Resampling and ROI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 111,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "./CHAOST2/niis/T2SPIR/image_1.nii.gz ./CHAOST2/niis/T2SPIR/label_1.nii.gz\n",
+ "(36, 256, 256) label shape (36, 256, 256)\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "./SABS/sabs_CT_normalized/image_1.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_1.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_10.nii.gz ./CHAOST2/niis/T2SPIR/label_10.nii.gz\n",
+ "(36, 256, 256) label shape (36, 256, 256)\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n",
+ "Size (254, 254, 36) -> [672. 672. 36.]\n",
+ "./SABS/sabs_CT_normalized/image_10.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_10.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_13.nii.gz ./CHAOST2/niis/T2SPIR/label_13.nii.gz\n",
+ "(30, 320, 320) label shape (30, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "./SABS/sabs_CT_normalized/image_13.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_13.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_15.nii.gz ./CHAOST2/niis/T2SPIR/label_15.nii.gz\n",
+ "(26, 256, 256) label shape (26, 256, 256)\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "./SABS/sabs_CT_normalized/image_15.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_15.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_19.nii.gz ./CHAOST2/niis/T2SPIR/label_19.nii.gz\n",
+ "(30, 320, 320) label shape (30, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "./SABS/sabs_CT_normalized/image_19.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_19.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_2.nii.gz ./CHAOST2/niis/T2SPIR/label_2.nii.gz\n",
+ "(26, 320, 320) label shape (26, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "./SABS/sabs_CT_normalized/image_2.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_2.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_20.nii.gz ./CHAOST2/niis/T2SPIR/label_20.nii.gz\n",
+ "(26, 320, 320) label shape (26, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 26) -> [672. 672. 26.]\n",
+ "./SABS/sabs_CT_normalized/image_20.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_20.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_21.nii.gz ./CHAOST2/niis/T2SPIR/label_21.nii.gz\n",
+ "(31, 256, 256) label shape (31, 256, 256)\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n",
+ "Size (254, 254, 31) -> [672. 672. 31.]\n",
+ "./SABS/sabs_CT_normalized/image_21.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_21.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_22.nii.gz ./CHAOST2/niis/T2SPIR/label_22.nii.gz\n",
+ "(33, 256, 256) label shape (33, 256, 256)\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n",
+ "Size (254, 254, 33) -> [672. 672. 33.]\n",
+ "./SABS/sabs_CT_normalized/image_22.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_22.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_3.nii.gz ./CHAOST2/niis/T2SPIR/label_3.nii.gz\n",
+ "(30, 320, 320) label shape (30, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 30) -> [672. 672. 30.]\n",
+ "./SABS/sabs_CT_normalized/image_3.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_3.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_31.nii.gz ./CHAOST2/niis/T2SPIR/label_31.nii.gz\n",
+ "(29, 256, 256) label shape (29, 256, 256)\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "./SABS/sabs_CT_normalized/image_31.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_31.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_32.nii.gz ./CHAOST2/niis/T2SPIR/label_32.nii.gz\n",
+ "(39, 256, 256) label shape (39, 256, 256)\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n",
+ "Size (254, 254, 39) -> [672. 672. 39.]\n",
+ "./SABS/sabs_CT_normalized/image_32.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_32.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_33.nii.gz ./CHAOST2/niis/T2SPIR/label_33.nii.gz\n",
+ "(29, 256, 256) label shape (29, 256, 256)\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n",
+ "Size (254, 254, 29) -> [672. 672. 29.]\n",
+ "./SABS/sabs_CT_normalized/image_33.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_33.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_34.nii.gz ./CHAOST2/niis/T2SPIR/label_34.nii.gz\n",
+ "(30, 256, 256) label shape (30, 256, 256)\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "./SABS/sabs_CT_normalized/image_34.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_34.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_36.nii.gz ./CHAOST2/niis/T2SPIR/label_36.nii.gz\n",
+ "(38, 256, 256) label shape (38, 256, 256)\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n",
+ "Size (254, 254, 38) -> [672. 672. 38.]\n",
+ "./SABS/sabs_CT_normalized/image_36.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_36.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_37.nii.gz ./CHAOST2/niis/T2SPIR/label_37.nii.gz\n",
+ "(32, 256, 256) label shape (32, 256, 256)\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n",
+ "Size (254, 254, 32) -> [672. 672. 32.]\n",
+ "./SABS/sabs_CT_normalized/image_37.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_37.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_38.nii.gz ./CHAOST2/niis/T2SPIR/label_38.nii.gz\n",
+ "(34, 320, 320) label shape (34, 320, 320)\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n",
+ "Size (318, 318, 34) -> [672. 672. 34.]\n",
+ "./SABS/sabs_CT_normalized/image_38.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_38.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_39.nii.gz ./CHAOST2/niis/T2SPIR/label_39.nii.gz\n",
+ "(26, 256, 256) label shape (26, 256, 256)\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n",
+ "Size (254, 254, 26) -> [672. 672. 26.]\n",
+ "./SABS/sabs_CT_normalized/image_39.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_39.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_5.nii.gz ./CHAOST2/niis/T2SPIR/label_5.nii.gz\n",
+ "(30, 256, 256) label shape (30, 256, 256)\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n",
+ "Size (254, 254, 30) -> [672. 672. 30.]\n",
+ "./SABS/sabs_CT_normalized/image_5.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_5.nii.gz has been saved\n",
+ "./CHAOST2/niis/T2SPIR/image_8.nii.gz ./CHAOST2/niis/T2SPIR/label_8.nii.gz\n",
+ "(32, 288, 288) label shape (32, 288, 288)\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "Label values: [0 1 2 3 4]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n",
+ "Size (286, 286, 32) -> [672. 672. 32.]\n",
+ "./SABS/sabs_CT_normalized/image_8.nii.gz has been saved\n",
+ "./SABS/sabs_CT_normalized/label_8.nii.gz has been saved\n"
+ ]
+ }
+ ],
+ "source": [
+ "# SPA_FAC = (256 - 2 * BD_BIAS) / 512 # spacing factor\n",
+ "BD_BIAS = 1\n",
+ "scan_dir = OUT_FOLDER\n",
+ "for res in (256, 672):\n",
+ " if res == 672:\n",
+ " scan_dir += \"_672\"\n",
+ " resample_imgs(imgs, segs, pids, scan_dir,\n",
+ " BD_BIAS, SPA_FAC=None, required_res=res)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MRI Classmap Generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pid 1 finished!\n",
+ "pid 2 finished!\n",
+ "pid 3 finished!\n",
+ "pid 5 finished!\n",
+ "pid 8 finished!\n",
+ "pid 10 finished!\n",
+ "pid 13 finished!\n",
+ "pid 15 finished!\n",
+ "pid 19 finished!\n",
+ "pid 20 finished!\n",
+ "pid 21 finished!\n",
+ "pid 22 finished!\n",
+ "pid 31 finished!\n",
+ "pid 32 finished!\n",
+ "pid 33 finished!\n",
+ "pid 34 finished!\n",
+ "pid 36 finished!\n",
+ "pid 37 finished!\n",
+ "pid 38 finished!\n",
+ "pid 39 finished!\n"
+ ]
+ }
+ ],
+ "source": [
+ "IMG_BNAMES = (\"./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz\", \"./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz\")\n",
+ "SEG_NAMES = (\"./CHAOST2/chaos_MR_T2_normalized/label_*.nii.gz\", \"./CHAOST2/chaos_MR_T2_normalized_672/label_*.nii.gz\")\n",
+ "\n",
+ "for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):\n",
+ " imgs = glob.glob(IMG_BNAME)\n",
+ " segs = glob.glob(SEG_BNAME)\n",
+ " imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n",
+ " segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n",
+ "\n",
+ "\n",
+ " classmap = {}\n",
+ " LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n",
+ "\n",
+ " MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n",
+ "\n",
+ " fid = os.path.join(OUT_FOLDER,f'.classmap_{MIN_TP}.json') # name of the output file. \n",
+ " for _lb in LABEL_NAME:\n",
+ " classmap[_lb] = {}\n",
+ " for _sid in segs:\n",
+ " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n",
+ " classmap[_lb][pid] = []\n",
+ "\n",
+ " for seg in segs:\n",
+ " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n",
+ " lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))\n",
+ " n_slice = lb_vol.shape[0]\n",
+ " for slc in range(n_slice):\n",
+ " for cls in range(len(LABEL_NAME)):\n",
+ " if cls in lb_vol[slc, ...]:\n",
+ " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n",
+ " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n",
+ " print(f'pid {str(pid)} finished!')\n",
+ " \n",
+ " with open(fid, 'w') as fopen:\n",
+ " json.dump(classmap, fopen)\n",
+ " fopen.close() \n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Psuedo label generation for Encoder Finetuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import copy\n",
+ "import skimage\n",
+ "\n",
+ "from skimage.segmentation import slic\n",
+ "from skimage.segmentation import mark_boundaries\n",
+ "from skimage.util import img_as_float\n",
+ "from skimage.measure import label \n",
+ "import scipy.ndimage.morphology as snm\n",
+ "from skimage import io\n",
+ "import argparse\n",
+ "\n",
+ "\n",
+ "to01 = lambda x: (x - x.min()) / (x.max() - x.min())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Summary\n",
+ "\n",
+ "a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background\n",
+ "\n",
+ "b. Generate superpixels as pseudolabels\n",
+ "\n",
+ "Configurations of pseudlabels\n",
+ "\n",
+ "default setting of minimum superpixel sizes\n",
+ "`segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)`\n",
+ "\n",
+ "you can also try other configs\n",
+ "`segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODE = 'MIDDLE' # minimum size of pesudolabels. 'MIDDLE' is the default setting\n",
+ "\n",
+ "# wrapper for process 3d image in 2d\n",
+ "def superpix_vol(img, method = 'fezlen', **kwargs):\n",
+ " \"\"\"\n",
+ " loop through the entire volume\n",
+ " assuming image with axis z, x, y\n",
+ " \"\"\"\n",
+ " if method =='fezlen':\n",
+ " seg_func = skimage.segmentation.felzenszwalb\n",
+ " else:\n",
+ " raise NotImplementedError\n",
+ " \n",
+ " out_vol = np.zeros(img.shape)\n",
+ " for ii in range(img.shape[0]):\n",
+ " if MODE == 'MIDDLE':\n",
+ " segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)\n",
+ " else:\n",
+ " raise NotImplementedError\n",
+ " out_vol[ii, ...] = segs\n",
+ " \n",
+ " return out_vol\n",
+ "\n",
+ "# thresholding the intensity values to get a binary mask of the patient\n",
+ "def fg_mask2d(img_2d, thresh): # change this by your need\n",
+ " mask_map = np.float32(img_2d > thresh)\n",
+ " \n",
+ " def getLargestCC(segmentation): # largest connected components\n",
+ " labels = label(segmentation)\n",
+ " assert( labels.max() != 0 ) # assume at least 1 CC\n",
+ " largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1\n",
+ " return largestCC\n",
+ " if mask_map.max() < 0.999:\n",
+ " return mask_map\n",
+ " else:\n",
+ " post_mask = getLargestCC(mask_map)\n",
+ " fill_mask = snm.binary_fill_holes(post_mask)\n",
+ " return fill_mask\n",
+ "\n",
+ "# remove superpixels within the empty regions\n",
+ "def superpix_masking(raw_seg2d, mask2d):\n",
+ " raw_seg2d = np.int32(raw_seg2d)\n",
+ " lbvs = np.unique(raw_seg2d)\n",
+ " max_lb = lbvs.max()\n",
+ " raw_seg2d[raw_seg2d == 0] = max_lb + 1\n",
+ " lbvs = list(lbvs)\n",
+ " lbvs.append( max_lb )\n",
+ " raw_seg2d = raw_seg2d * mask2d\n",
+ " lb_new = 1\n",
+ " out_seg2d = np.zeros(raw_seg2d.shape)\n",
+ " for lbv in lbvs:\n",
+ " if lbv == 0:\n",
+ " continue\n",
+ " else:\n",
+ " out_seg2d[raw_seg2d == lbv] = lb_new\n",
+ " lb_new += 1\n",
+ " \n",
+ " return out_seg2d\n",
+ " \n",
+ "def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):\n",
+ " raw_seg = superpix_vol(img)\n",
+ " fg_mask_vol = np.zeros(raw_seg.shape)\n",
+ " processed_seg_vol = np.zeros(raw_seg.shape)\n",
+ " for ii in range(raw_seg.shape[0]):\n",
+ " if verbose:\n",
+ " print(\"doing {} slice\".format(ii))\n",
+ " _fgm = fg_mask2d(img[ii, ...], fg_thresh )\n",
+ " _out_seg = superpix_masking(raw_seg[ii, ...], _fgm)\n",
+ " fg_mask_vol[ii] = _fgm\n",
+ " processed_seg_vol[ii] = _out_seg\n",
+ " return fg_mask_vol, processed_seg_vol\n",
+ " \n",
+ "# copy spacing and orientation info between sitk objects\n",
+ "def copy_info(src, dst):\n",
+ " dst.SetSpacing(src.GetSpacing())\n",
+ " dst.SetOrigin(src.GetOrigin())\n",
+ " dst.SetDirection(src.GetDirection())\n",
+ " # dst.CopyInfomation(src)\n",
+ " return dst\n",
+ "\n",
+ "\n",
+ "def strip_(img, lb):\n",
+ " img = np.int32(img)\n",
+ " if isinstance(lb, float):\n",
+ " lb = int(lb)\n",
+ " return np.float32(img == lb) * float(lb)\n",
+ " elif isinstance(lb, list):\n",
+ " out = np.zeros(img.shape)\n",
+ " for _lb in lb:\n",
+ " out += np.float32(img == int(_lb)) * float(_lb)\n",
+ " \n",
+ " return out\n",
+ " else:\n",
+ " raise Exception"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATASET_CONFIG = {'SABS':{\n",
+ " 'img_bname': f'./SABS/sabs_CT_normalized/image_*.nii.gz',\n",
+ " 'out_dir': './SABS/sabs_CT_normalized',\n",
+ " 'fg_thresh': 1e-4\n",
+ " },\n",
+ " 'CHAOST2':{\n",
+ " 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',\n",
+ " 'out_dir': './CHAOST2/chaos_MR_T2_normalized',\n",
+ " 'fg_thresh': 1e-4 + 50\n",
+ " },\n",
+ " 'SABS_672':{\n",
+ " 'img_bname': f'./SABS/sabs_CT_normalized_672/image_*.nii.gz',\n",
+ " 'out_dir': './SABS/sabs_CT_normalized_672',\n",
+ " 'fg_thresh': 1e-4\n",
+ " },\n",
+ " 'CHAOST2_672':{\n",
+ " 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz',\n",
+ " 'out_dir': './CHAOST2/chaos_MR_T2_normalized_672',\n",
+ " 'fg_thresh': 1e-4 + 50\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "for DOMAIN in DATASET_CONFIG.keys():\n",
+ " img_bname = DATASET_CONFIG[DOMAIN]['img_bname']\n",
+ " imgs = glob.glob(img_bname)\n",
+ " out_dir = DATASET_CONFIG[DOMAIN]['out_dir']\n",
+ "\n",
+ " imgs = sorted(imgs, key = lambda x: int(x.split('_')[-1].split('.nii.gz')[0]) )\n",
+ " print(imgs)\n",
+ "\n",
+ " # Generate pseudolabels for every image and save them\n",
+ " for img_fid in imgs:\n",
+ " # img_fid = imgs[0]\n",
+ "\n",
+ " idx = os.path.basename(img_fid).split(\"_\")[-1].split(\".nii.gz\")[0]\n",
+ " im_obj = sitk.ReadImage(img_fid)\n",
+ "\n",
+ " out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'] )\n",
+ " out_fg_o = sitk.GetImageFromArray(out_fg ) \n",
+ " out_seg_o = sitk.GetImageFromArray(out_seg )\n",
+ "\n",
+ " out_fg_o = copy_info(im_obj, out_fg_o)\n",
+ " out_seg_o = copy_info(im_obj, out_seg_o)\n",
+ " seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')\n",
+ " msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')\n",
+ " sitk.WriteImage(out_fg_o, msk_fid)\n",
+ " sitk.WriteImage(out_seg_o, seg_fid)\n",
+ " print(f'image with id {idx} has finished')\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "lev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dataloaders/GenericSuperDatasetv2.py b/dataloaders/GenericSuperDatasetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf04917ecb778cf49a88af71d459e35891646ae
--- /dev/null
+++ b/dataloaders/GenericSuperDatasetv2.py
@@ -0,0 +1,445 @@
+"""
+Dataset for training with pseudolabels
+TODO:
+1. Merge with the manually annotated dataset
+2. superpixel_scale -> superpix_config, fed in as a dict
+"""
+import glob
+import numpy as np
+import dataloaders.augutils as myaug
+import torch
+import random
+import os
+import copy
+import platform
+import json
+import re
+import cv2
+from dataloaders.common import BaseDataset, Subset
+from dataloaders.dataset_utils import *
+from pdb import set_trace
+from util.utils import CircularList
+from util.consts import IMG_SIZE
+
+class SuperpixelDataset(BaseDataset):
+ def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], train_list = [], superpix_scale = 'SMALL', norm_mean=None, norm_std=None, supervised_train=False, use_3_slices=False, **kwargs):
+ """
+ Pseudolabel dataset
+ Args:
+ which_dataset: name of the dataset to use
+ base_dir: directory of dataset
+ idx_split: index of data split as we will do cross validation
+ mode: 'train', 'val'.
+            nsup: number of scans used as support. Currently unused for the superpixel dataset
+            transforms: data transform (augmentation) function
+            scan_per_load: load only a portion of the entire dataset, in case the dataset is too large to fit into memory. Set to -1 to load the entire dataset at once
+            num_rep: number of augmented views generated from the same pseudolabel
+ tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
+ fix_length: fix the length of dataset
+ exclude_list: Labels to be excluded
+ superpix_scale: config of superpixels
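+
+        Example (hypothetical arguments, for illustration only):
+            dataset = SuperpixelDataset('CHAOST2', base_dir='./CHAOST2/chaos_MR_T2_normalized',
+                                        idx_split=0, mode='train', image_size=256, transforms=None,
+                                        scan_per_load=-1, superpix_scale='MIDDLE', use_clahe=False)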
+ """
+ super(SuperpixelDataset, self).__init__(base_dir)
+
+ self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
+ self.sep = DATASET_INFO[which_dataset]['_SEP']
+ self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME']
+ self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
+
+ self.image_size = image_size
+ self.transforms = transforms
+ self.is_train = True if mode == 'train' else False
+ self.supervised_train = supervised_train
+ if self.supervised_train and len(train_list) == 0:
+ raise Exception('Please provide training labels')
+ # assert mode == 'train'
+ self.fix_length = fix_length
+ if self.supervised_train:
+ # self.nclass = len(self.real_label_name)
+ self.nclass = len(self.pseu_label_name)
+ else:
+ self.nclass = len(self.pseu_label_name)
+ self.num_rep = num_rep
+ self.tile_z_dim = tile_z_dim
+ self.use_3_slices = use_3_slices
+ if tile_z_dim > 1 and self.use_3_slices:
+ raise Exception("tile_z_dim and use_3_slices shouldn't be used together")
+
+ # find scans in the data folder
+ self.nsup = nsup
+ self.base_dir = base_dir
+        self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ]
+ self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x)))
+
+ # experiment configs
+ self.exclude_lbs = exclude_list
+ self.train_list = train_list
+ self.superpix_scale = superpix_scale
+ if len(exclude_list) > 0:
+            print(f'###### Dataset: the following classes have been excluded: {exclude_list} ######')
+ self.idx_split = idx_split
+ self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
+ self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
+ self.scan_per_load = scan_per_load
+
+ self.info_by_scan = None
+ self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
+ self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()], ct_mean=norm_mean, ct_std=norm_std)
+
+ if self.is_train:
+ if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch
+ self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
+ else: # load the entire set without a buffer
+ self.pid_curr_load = self.scan_ids
+ elif mode == 'val':
+ self.pid_curr_load = self.scan_ids
+ else:
+ raise Exception
+
+        self.use_clahe = bool(kwargs.get('use_clahe', False))
+        if self.use_clahe:
+ clip_limit = 4.0 if self.img_modality == 'MR' else 2.0
+ self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7,7))
+
+ self.actual_dataset = self.read_dataset()
+ self.size = len(self.actual_dataset)
+ self.overall_slice_by_cls = self.read_classfiles()
+
+ print("###### Initial scans loaded: ######")
+ print(self.pid_curr_load)
+
+ def get_scanids(self, mode, idx_split):
+ """
+ Load scans by train-test split
+        leaving one additional scan as the support scan. For the last fold, scan 0 is taken as the additional one
+        Args:
+            idx_split: index for splitting cross-validation folds
+ """
+ val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
+ if mode == 'train':
+ return [ ii for ii in self.img_pids if ii not in val_ids ]
+ elif mode == 'val':
+ return val_ids
+
+ def reload_buffer(self):
+ """
+        Reload only a portion of the entire dataset, if the dataset is too large
+        1. delete original buffer
+        2. update self.ids_this_batch
+        3. update other internal variables like __len__
+ """
+ if self.scan_per_load <= 0:
+ print("We are not using the reload buffer, doing notiong")
+ return -1
+
+ del self.actual_dataset
+ del self.info_by_scan
+
+ self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
+ self.actual_dataset = self.read_dataset()
+ self.size = len(self.actual_dataset)
+ self.update_subclass_lookup()
+ print(f'Loader buffer reloaded with a new size of {self.size} slices')
+
+ def organize_sample_fids(self):
+ out_list = {}
+ for curr_id in self.scan_ids:
+ curr_dict = {}
+
+ _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
+ _lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz')
+ _gt_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
+
+ curr_dict["img_fid"] = _img_fid
+ curr_dict["lbs_fid"] = _lb_fid
+ curr_dict["gt_lbs_fid"] = _gt_lb_fid
+ out_list[str(curr_id)] = curr_dict
+ return out_list
+
+ def read_dataset(self):
+ """
+        Read images into memory and store them as 2D slices
+        Build lookup tables for the position of each individual 2D slice in the entire dataset
+ """
+ out_list = []
+ self.scan_z_idx = {}
+ self.info_by_scan = {} # meta data of each scan
+ glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
+
+ for scan_id, itm in self.img_lb_fids.items():
+ if scan_id not in self.pid_curr_load:
+ continue
+
+ img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
+ # read connected graph of labels
+ if self.use_clahe:
+                # rescale each MR slice to [0, 255] before applying CLAHE slice-by-slice
+                if self.img_modality == 'MR':
+                    img = np.stack([((sl - sl.min()) / (sl.max() - sl.min())) * 255 for sl in img], axis=0)
+                img = np.stack([self.clahe.apply(sl.astype(np.uint8)) for sl in img], axis=0)
+
+ img = img.transpose(1,2,0)
+ self.info_by_scan[scan_id] = _info
+
+ img = np.float32(img)
+ img = self.norm_func(img)
+ self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
+
+ if self.supervised_train:
+ lb = read_nii_bysitk(itm["gt_lbs_fid"])
+ else:
+ lb = read_nii_bysitk(itm["lbs_fid"])
+ lb = lb.transpose(1,2,0)
+ lb = np.int32(lb)
+
+ # resize img and lb to self.image_size
+ img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
+ lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
+
+ # format of slices: [axial_H x axial_W x Z]
+ if self.supervised_train:
+                # remove all slices that don't contain any of the training labels
+                del_indices = [i for i in range(img.shape[-1]) if not np.any(np.isin(lb[..., i], self.train_list))]
+                keep = ~np.isin(np.arange(img.shape[-1]), del_indices)
+                new_img = img[..., keep]
+                new_lb = lb[..., keep]
+
+ img = new_img
+ lb = new_lb
+ a = [i for i in range(img.shape[-1]) if lb[...,i].max() == 0]
+
+ nframes = img.shape[-1]
+ assert img.shape[-1] == lb.shape[-1]
+ base_idx = img.shape[-1] // 2 # index of the middle slice
+
+ # re-organize 3D images into 2D slices and record essential information for each slice
+ out_list.append( {"img": img[..., 0: 1],
+ "lb":lb[..., 0: 0 + 1],
+ "sup_max_cls": lb[..., 0: 0 + 1].max(),
+ "is_start": True,
+ "is_end": False,
+ "nframe": nframes,
+ "scan_id": scan_id,
+ "z_id":0,
+ })
+
+ self.scan_z_idx[scan_id][0] = glb_idx
+ glb_idx += 1
+
+ for ii in range(1, img.shape[-1] - 1):
+ out_list.append( {"img": img[..., ii: ii + 1],
+ "lb":lb[..., ii: ii + 1],
+ "is_start": False,
+ "is_end": False,
+ "sup_max_cls": lb[..., ii: ii + 1].max(),
+ "nframe": nframes,
+ "scan_id": scan_id,
+ "z_id": ii,
+ })
+ self.scan_z_idx[scan_id][ii] = glb_idx
+ glb_idx += 1
+
+ ii += 1 # last slice of a 3D volume
+ out_list.append( {"img": img[..., ii: ii + 1],
+ "lb":lb[..., ii: ii+ 1],
+ "is_start": False,
+ "is_end": True,
+ "sup_max_cls": lb[..., ii: ii + 1].max(),
+ "nframe": nframes,
+ "scan_id": scan_id,
+ "z_id": ii,
+ })
+
+ self.scan_z_idx[scan_id][ii] = glb_idx
+ glb_idx += 1
+
+ return out_list
+
+ def read_classfiles(self):
+ """
+ Load the scan-slice-class indexing file
+ """
+ with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
+ cls_map = json.load( fopen)
+ fopen.close()
+
+ with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
+ self.tp1_cls_map = json.load( fopen)
+ fopen.close()
+
+ return cls_map
+
+ def get_superpixels_similarity(self, sp1, sp2):
+ pass
+
+ def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
+ if bi_val is None:
+ # bi_val = np.random.randint(1, sup_max_cls)
+ bi_val = random.choice(list(np.unique(super_map)))
+ if conn_graph is not None and img is not None:
+ # get number of neighbors of bi_val
+ neighbors = conn_graph[bi_val]
+ # pick a random number of neighbors and merge them
+ n_neighbors = np.random.randint(0, len(neighbors))
+ try:
+ neighbors = random.sample(neighbors, n_neighbors)
+ except TypeError:
+ neighbors = []
+ # merge neighbors
+ super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
+ return np.float32(super_map == bi_val)
+
+ def supcls_pick(self, super_map):
+ return random.choice(list(np.unique(super_map)))
+
+ def get_3_slice_adjacent_image(self, image_t, index):
+ curr_dict = self.actual_dataset[index]
+ prev_image = np.zeros_like(image_t)
+
+ if index > 1 and not curr_dict["is_start"]:
+ prev_dict = self.actual_dataset[index - 1]
+ prev_image = prev_dict["img"]
+
+ next_image = np.zeros_like(image_t)
+ if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
+ next_dict = self.actual_dataset[index + 1]
+ next_image = next_dict["img"]
+
+ image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
+
+ return image_t
+
+ def __getitem__(self, index):
+ index = index % len(self.actual_dataset)
+ curr_dict = self.actual_dataset[index]
+ sup_max_cls = curr_dict['sup_max_cls']
+ if sup_max_cls < 1:
+ return self.__getitem__(index + 1)
+
+ image_t = curr_dict["img"]
+ label_raw = curr_dict["lb"]
+
+ if self.use_3_slices:
+ image_t = self.get_3_slice_adjacent_image(image_t, index)
+
+ for _ex_cls in self.exclude_lbs:
+ if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen
+ return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
+
+ if self.supervised_train:
+ superpix_label = -1
+ label_t = np.float32(label_raw)
+
+ lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list)))
+ label_t[label_t != lb_id] = 0
+ label_t[label_t == lb_id] = 1
+
+ else:
+ superpix_label = self.supcls_pick(label_raw)
+ label_t = np.float32(label_raw == superpix_label)
+
+ pair_buffer = []
+
+ comp = np.concatenate( [image_t, label_t], axis = -1 )
+
+ for ii in range(self.num_rep):
+ if self.transforms is not None:
+ img, lb = self.transforms(comp, c_img = image_t.shape[-1], c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False)
+ else:
+                # split comp back into the image channels and the single label channel
+                img, lb = comp[:, :, :image_t.shape[-1]], comp[:, :, image_t.shape[-1]:]
+ # if ii % 2 == 0:
+ # label_raw = lb
+ # lb = lb == superpix_label
+
+ img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ).float()
+ lb = torch.from_numpy( lb.squeeze(-1)).float()
+
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
+
+ is_start = curr_dict["is_start"]
+ is_end = curr_dict["is_end"]
+ nframe = np.int32(curr_dict["nframe"])
+ scan_id = curr_dict["scan_id"]
+ z_id = curr_dict["z_id"]
+
+ sample = {"image": img,
+ "label":lb,
+ "is_start": is_start,
+ "is_end": is_end,
+ "nframe": nframe,
+ "scan_id": scan_id,
+ "z_id": z_id
+ }
+
+ # Add auxiliary attributes
+ if self.aux_attrib is not None:
+ for key_prefix in self.aux_attrib:
+ # Process the data sample, create new attributes and save them in a dictionary
+ aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
+ for key_suffix in aux_attrib_val:
+ # one function may create multiple attributes, so we need suffix to distinguish them
+ sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
+ pair_buffer.append(sample)
+
+ support_images = []
+ support_mask = []
+ support_class = []
+
+ query_images = []
+ query_labels = []
+ query_class = []
+
+ for idx, itm in enumerate(pair_buffer):
+ if idx % 2 == 0:
+ support_images.append(itm["image"])
+ support_class.append(1) # pseudolabel class
+ support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] ))
+ else:
+ query_images.append(itm["image"])
+ query_class.append(1)
+ query_labels.append( itm["label"])
+
+ return {'class_ids': [support_class],
+ 'support_images': [support_images], #
+ 'superpix_label': superpix_label,
+ 'superpix_label_raw': label_raw[:,:,0],
+ 'support_mask': [support_mask],
+ 'query_images': query_images, #
+ 'query_labels': query_labels,
+ 'scan_id': scan_id,
+ 'z_id': z_id,
+ 'nframe': nframe,
+ }
+
+
+ def __len__(self):
+ """
+        Return fix_length if it is set (same convention as the basic dataset); otherwise the number of loaded slices
+        """
+        if self.fix_length is not None:
+ assert self.fix_length >= len(self.actual_dataset)
+ return self.fix_length
+ else:
+ return len(self.actual_dataset)
+
+ def getMaskMedImg(self, label, class_id, class_ids):
+ """
+ Generate FG/BG mask from the segmentation mask
+
+ Args:
+ label: semantic mask
+ class_id: semantic class of interest
+ class_ids: all class id in this episode
+ """
+ fg_mask = torch.where(label == class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+ bg_mask = torch.where(label != class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+        for cid in class_ids:
+            bg_mask[label == cid] = 0
+
+ return {'fg_mask': fg_mask,
+ 'bg_mask': bg_mask}
diff --git a/dataloaders/ManualAnnoDatasetv2.py b/dataloaders/ManualAnnoDatasetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8312d20514c5177d03e41ec05b3c8b8c89cf6e0
--- /dev/null
+++ b/dataloaders/ManualAnnoDatasetv2.py
@@ -0,0 +1,756 @@
+"""
+Manually labeled dataset
+TODO:
+1. Merge with superpixel dataset
+"""
+import glob
+import numpy as np
+import dataloaders.augutils as myaug
+import torch
+import random
+import os
+import copy
+import platform
+import json
+import re
+import cv2
+from dataloaders.common import BaseDataset, Subset, ValidationDataset
+# from common import BaseDataset, Subset
+from dataloaders.dataset_utils import *
+from pdb import set_trace
+from util.utils import CircularList
+from util.consts import IMG_SIZE
+
+MODE_DEFAULT = "default"
+MODE_FULL_SCAN = "full_scan"
+
+class ManualAnnoDataset(BaseDataset):
+ def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None, **kwargs):
+ """
+ Manually labeled dataset
+ Args:
+ which_dataset: name of the dataset to use
+ base_dir: directory of dataset
+ idx_split: index of data split as we will do cross validation
+ mode: 'train', 'val'.
+ transforms: data transform (augmentation) function
+            min_fg: minimum number of positive pixels in a 2D slice, mainly to stabilize training when training on the manually labeled dataset
+            scan_per_load: load only a portion of the entire dataset, in case the dataset is too large to fit into memory. Set to -1 to load the entire dataset at once
+ tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
+ nsup: number of support scans
+ fix_length: fix the length of dataset
+ exclude_list: Labels to be excluded
+            extern_normalize_func: external normalization function used for data pre-processing
+ """
+ super(ManualAnnoDataset, self).__init__(base_dir)
+ self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
+ self.sep = DATASET_INFO[which_dataset]['_SEP']
+ self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
+ self.image_size = image_size
+ self.transforms = transforms
+ self.is_train = True if mode == 'train' else False
+ self.phase = mode
+ self.fix_length = fix_length
+ self.all_label_names = self.label_name
+ self.nclass = len(self.label_name)
+ self.tile_z_dim = tile_z_dim
+ self.base_dir = base_dir
+ self.nsup = nsup
+ self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ]
+ self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of splitting folds
+ if 'use_clahe' not in kwargs:
+ self.use_clahe = False
+ else:
+ self.use_clahe = kwargs['use_clahe']
+ if self.use_clahe:
+ self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7))
+
+ self.use_3_slices = kwargs["use_3_slices"] if 'use_3_slices' in kwargs else False
+ if self.use_3_slices:
+ self.tile_z_dim=1
+
+ self.get_item_mode = MODE_DEFAULT
+ if 'get_item_mode' in kwargs:
+ self.get_item_mode = kwargs['get_item_mode']
+
+ self.exclude_lbs = exclude_list
+ if len(exclude_list) > 0:
+ print(f'###### Dataset: the following classes have been excluded {exclude_list} ######')
+
+ self.idx_split = idx_split
+ self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
+ self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
+
+ self.scan_per_load = scan_per_load
+
+ self.info_by_scan = None
+ self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
+
+ if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset.
+ self.norm_func = extern_normalize_func
+ print(f'###### Dataset: using external normalization statistics ######')
+ else:
+ self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()])
+ print(f'###### Dataset: using normalization statistics calculated from loaded data ######')
+
+ if self.is_train:
+ if scan_per_load > 0: # buffer needed
+ self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
+ else: # load the entire set without a buffer
+ self.pid_curr_load = self.scan_ids
+ elif mode == 'val':
+ self.pid_curr_load = self.scan_ids
+ self.potential_support_sid = []
+ else:
+ raise Exception
+ self.actual_dataset = self.read_dataset()
+ self.size = len(self.actual_dataset)
+ self.overall_slice_by_cls = self.read_classfiles()
+ self.update_subclass_lookup()
+
+ def get_scanids(self, mode, idx_split):
+ val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
+ self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index
+ if mode == 'train':
+ return [ ii for ii in self.img_pids if ii not in val_ids ]
+ elif mode == 'val':
+ return val_ids
+
+ def reload_buffer(self):
+ """
+ Reload a portion of the entire dataset, if the dataset is too large
+ 1. delete original buffer
+ 2. update self.ids_this_batch
+ 3. update other internal variables like __len__
+ """
+ if self.scan_per_load <= 0:
+ print("We are not using the reload buffer, doing notiong")
+ return -1
+
+ del self.actual_dataset
+ del self.info_by_scan
+ self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
+ self.actual_dataset = self.read_dataset()
+ self.size = len(self.actual_dataset)
+ self.update_subclass_lookup()
+ print(f'Loader buffer reloaded with a new size of {self.size} slices')
+
+ def organize_sample_fids(self):
+ out_list = {}
+ for curr_id in self.scan_ids:
+ curr_dict = {}
+
+ _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
+ _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
+
+ curr_dict["img_fid"] = _img_fid
+ curr_dict["lbs_fid"] = _lb_fid
+ out_list[str(curr_id)] = curr_dict
+ return out_list
+
+ def read_dataset(self):
+ """
+ Build index pointers to individual slices
+ Also keep a look-up table from scan_id, slice to index
+ """
+ out_list = []
+ self.scan_z_idx = {}
+ self.info_by_scan = {} # meta data of each scan
+ glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
+
+ for scan_id, itm in self.img_lb_fids.items():
+ if scan_id not in self.pid_curr_load:
+ continue
+
+ img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
+
+ img = img.transpose(1,2,0)
+
+ self.info_by_scan[scan_id] = _info
+
+ if self.use_clahe:
+ # apply CLAHE slice-by-slice along the depth (last) axis
+ img = np.stack([self.clahe.apply(img[..., ii].astype(np.uint8)) for ii in range(img.shape[-1])], axis=-1)
+
+ img = np.float32(img)
+ img = self.norm_func(img)
+
+ self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
+
+ lb = read_nii_bysitk(itm["lbs_fid"])
+ lb = lb.transpose(1,2,0)
+
+ lb = np.float32(lb)
+
+ img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
+ lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
+
+ assert img.shape[-1] == lb.shape[-1]
+ base_idx = img.shape[-1] // 2 # index of the middle slice
+
+ # write the beginning frame
+ out_list.append( {"img": img[..., 0: 1],
+ "lb":lb[..., 0: 0 + 1],
+ "is_start": True,
+ "is_end": False,
+ "nframe": img.shape[-1],
+ "scan_id": scan_id,
+ "z_id":0})
+
+ self.scan_z_idx[scan_id][0] = glb_idx
+ glb_idx += 1
+
+ for ii in range(1, img.shape[-1] - 1):
+ out_list.append( {"img": img[..., ii: ii + 1],
+ "lb":lb[..., ii: ii + 1],
+ "is_start": False,
+ "is_end": False,
+ "nframe": -1,
+ "scan_id": scan_id,
+ "z_id": ii
+ })
+ self.scan_z_idx[scan_id][ii] = glb_idx
+ glb_idx += 1
+
+ ii += 1 # last frame, note the is_end flag
+ out_list.append( {"img": img[..., ii: ii + 1],
+ "lb":lb[..., ii: ii+ 1],
+ "is_start": False,
+ "is_end": True,
+ "nframe": -1,
+ "scan_id": scan_id,
+ "z_id": ii
+ })
+
+ self.scan_z_idx[scan_id][ii] = glb_idx
+ glb_idx += 1
+
+ return out_list
+
+ def read_classfiles(self):
+ with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
+ cls_map = json.load( fopen)
+ fopen.close()
+
+ with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
+ self.tp1_cls_map = json.load( fopen)
+ fopen.close()
+
+ return cls_map
+
+ def __getitem__(self, index):
+ if self.get_item_mode == MODE_DEFAULT:
+ return self.__getitem_default__(index)
+ elif self.get_item_mode == MODE_FULL_SCAN:
+ return self.__get_ct_scan___(index)
+ else:
+ raise NotImplementedError("Unknown mode")
+
+
+ def __get_ct_scan___(self, index):
+ scan_n = index % len(self.scan_z_idx)
+ scan_id = list(self.scan_z_idx.keys())[scan_n]
+ scan_slices = self.scan_z_idx[scan_id]
+
+ scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
+
+ scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
+
+ scan_imgs = np.float32(scan_imgs)
+ scan_lbs = np.float32(scan_lbs)
+
+ scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0)
+ scan_lbs = torch.from_numpy(scan_lbs)
+
+ if self.tile_z_dim:
+ scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
+ assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
+
+ # # reshape to C, D, H, W
+ # scan_imgs = scan_imgs.permute(1, 0, 2, 3)
+ # scan_lbs = scan_lbs.permute(1, 0, 2, 3)
+
+ sample = {"image": scan_imgs,
+ "label":scan_lbs,
+ "scan_id": scan_id,
+ }
+
+ return sample
+
+
+ def get_3_slice_adjacent_image(self, image_t, index):
+ curr_dict = self.actual_dataset[index]
+ prev_image = np.zeros_like(image_t)
+
+ if index > 1 and not curr_dict["is_start"]:
+ prev_dict = self.actual_dataset[index - 1]
+ prev_image = prev_dict["img"]
+
+ next_image = np.zeros_like(image_t)
+ if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
+ next_dict = self.actual_dataset[index + 1]
+ next_image = next_dict["img"]
+
+ image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
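+ # resulting array is H x W x 3 with channels ordered [previous, current, next];
+ # missing neighbours at scan boundaries stay zero-filled (see the checks above)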
+
+ return image_t
+
+
+ def __getitem_default__(self, index):
+ index = index % len(self.actual_dataset)
+ curr_dict = self.actual_dataset[index]
+ if self.is_train:
+ if len(self.exclude_lbs) > 0:
+ for _ex_cls in self.exclude_lbs:
+ if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen
+ return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
+
+ comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 )
+ if self.transforms is not None:
+ img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False)
+ else:
+ raise Exception("No transform function is provided")
+
+ else:
+ img = curr_dict['img']
+ lb = curr_dict['lb']
+
+
+ img = np.float32(img)
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
+ if self.use_3_slices:
+ img = self.get_3_slice_adjacent_image(img, index)
+
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
+ lb = torch.from_numpy( lb)
+
+ if self.tile_z_dim:
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
+
+ is_start = curr_dict["is_start"]
+ is_end = curr_dict["is_end"]
+ nframe = np.int32(curr_dict["nframe"])
+ scan_id = curr_dict["scan_id"]
+ z_id = curr_dict["z_id"]
+
+ sample = {"image": img,
+ "label":lb,
+ "is_start": is_start,
+ "is_end": is_end,
+ "nframe": nframe,
+ "scan_id": scan_id,
+ "z_id": z_id
+ }
+ # Add auxiliary attributes
+ if self.aux_attrib is not None:
+ for key_prefix in self.aux_attrib:
+ # Process the data sample, create new attributes and save them in a dictionary
+ aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
+ for key_suffix in aux_attrib_val:
+ # one function may create multiple attributes, so we need suffix to distinguish them
+ sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
+
+ return sample
+
+ def __len__(self):
+ """
+ Same behaviour as the basic dataset configuration; in full-scan mode the length is the number of scans, otherwise fix_length (if set) or the number of slices.
+ """
+ if self.get_item_mode == MODE_FULL_SCAN:
+ return len(self.scan_z_idx)
+
+ if self.fix_length is not None:
+ assert self.fix_length >= len(self.actual_dataset)
+ return self.fix_length
+ else:
+ return len(self.actual_dataset)
+
+ def update_subclass_lookup(self):
+ """
+ Updating the class-slice indexing list
+ Args:
+ [internal] overall_slice_by_cls:
+ {
+ class1: {pid1: [slice1, slice2, ....],
+ pid2: [slice1, slice2]},
+ ...}
+ class2:
+ ...
+ }
+ out[internal]:
+ {
+ class1: [ idx1, idx2, ... ],
+ class2: [ idx1, idx2, ... ],
+ ...
+ }
+
+ """
+ # delete previous ones if any
+ assert self.overall_slice_by_cls is not None
+
+ if not hasattr(self, 'idx_by_class'):
+ self.idx_by_class = {}
+ # filter the new one given the actual list
+ for cls in self.label_name:
+ if cls not in self.idx_by_class.keys():
+ self.idx_by_class[cls] = []
+ else:
+ del self.idx_by_class[cls][:]
+ for cls, dict_by_pid in self.overall_slice_by_cls.items():
+ for pid, slice_list in dict_by_pid.items():
+ if pid not in self.pid_curr_load:
+ continue
+ self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ]
+ print("###### index-by-class table has been reloaded ######")
+
+ def getMaskMedImg(self, label, class_id, class_ids):
+ """
+ Generate FG/BG mask from the segmentation mask. Used when getting the support
+ """
+ # Dense Mask
+ fg_mask = torch.where(label == class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+ bg_mask = torch.where(label != class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+ for class_id in class_ids:
+ bg_mask[label == class_id] = 0
+
+ return {'fg_mask': fg_mask,
+ 'bg_mask': bg_mask}
+
+ def subsets(self, sub_args_lst=None):
+ """
+ Override base-class subset method
+ Create one subset per foreground class, using the class-to-slice index table
+
+ output: list of Subset objects, one per class
+ """
+
+ if sub_args_lst is not None:
+ subsets = []
+ ii = 0
+ for cls_name, index_list in self.idx_by_class.items():
+ subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) )
+ ii += 1
+ else:
+ subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()]
+ return subsets
+
+ def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int):
+ """
+ Get a (possibly multi-shot) support set for evaluation.
+ Support slices are sampled at evenly spaced relative depths of the scan
+ (the middle slice for 1-shot; e.g. 10/30/50/70/90% of the depth for 5 parts).
+ Args:
+ curr_class: current class to segment, starts from 1
+ class_idx: a list of all foreground classes in the n-way episode, starts from 1
+ npart: how many chunks are used to split the support
+ scan_idx: a list of the **i-th** training scans (indices into self.pid_curr_load,
+ not patient ids) to be used as support
+ """
+ assert npart % 2 == 1
+ assert curr_class != 0; assert 0 not in class_idx
+ # assert not self.is_train
+
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
+
+ if npart == 1:
+ pcts = [0.5]
+ else:
+ half_part = 1 / (npart * 2)
+ part_interval = (1.0 - 1.0 / npart) / (npart - 1)
+ pcts = [ half_part + part_interval * ii for ii in range(npart) ]
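+ # e.g. npart = 5 gives pcts = [0.1, 0.3, 0.5, 0.7, 0.9], i.e. evenly spaced relative depths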
+
+ # print(f'###### Parts percentage: {pcts} ######')
+
+ # norm_func = get_normalize_op(modality='MR', fids=None)
+ out_buffer = [] # [{scanid, img, lb}]
+ for _part in range(npart):
+ concat_buffer = [] # for each part, concatenate images and masks along the batch dimension
+ for scan_order in scan_idx:
+ _scan_id = self.pid_curr_load[ scan_order ]
+ print(f'Using scan {_scan_id} as support!')
+
+ # for _pc in pcts:
+ _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices
+ _zid = _zlist[int(pcts[_part] * len(_zlist))]
+ _glb_idx = self.scan_z_idx[_scan_id][_zid]
+
+ # almost copy-paste __getitem__ but no augmentation
+ curr_dict = self.actual_dataset[_glb_idx]
+ img = curr_dict['img']
+ lb = curr_dict['lb']
+
+ if self.use_3_slices:
+ prev_image = np.zeros_like(img)
+ if _glb_idx > 1 and not curr_dict["is_start"]:
+ prev_dict = self.actual_dataset[_glb_idx - 1]
+ prev_image = prev_dict["img"]
+
+ next_image = np.zeros_like(img)
+ if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
+ next_dict = self.actual_dataset[_glb_idx + 1]
+ next_image = next_dict["img"]
+
+ img = np.concatenate([prev_image, img, next_image], axis=-1)
+
+ img = np.float32(img)
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
+
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
+ lb = torch.from_numpy( lb )
+
+ if self.tile_z_dim:
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
+
+ is_start = curr_dict["is_start"]
+ is_end = curr_dict["is_end"]
+ nframe = np.int32(curr_dict["nframe"])
+ scan_id = curr_dict["scan_id"]
+ z_id = curr_dict["z_id"]
+
+ sample = {"image": img,
+ "label":lb,
+ "is_start": is_start,
+ "inst": None,
+ "scribble": None,
+ "is_end": is_end,
+ "nframe": nframe,
+ "scan_id": scan_id,
+ "z_id": z_id
+ }
+
+ concat_buffer.append(sample)
+ out_buffer.append({
+ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
+ "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
+
+ })
+
+ # do the concat, and add to output_buffer
+
+ # post-processing, including keeping the foreground and suppressing background.
+ support_images = []
+ support_mask = []
+ support_class = []
+ for itm in out_buffer:
+ support_images.append(itm["image"])
+ support_class.append(curr_class)
+ support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
+
+ return {'class_ids': [support_class],
+ 'support_images': [support_images], #
+ 'support_mask': [support_mask],
+ }
+
+ def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list):
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
+ scan_slices = self.scan_z_idx[self.potential_support_sid[0]]
+ scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
+
+ scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
+ # binarize the labels
+ scan_lbs[scan_lbs != curr_class] = 0
+ scan_lbs[scan_lbs == curr_class] = 1
+
+ scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0)
+ scan_lbs = torch.from_numpy(np.float32(scan_lbs))
+
+ if self.tile_z_dim:
+ scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
+ assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
+
+ # reshape to C, D, H, W
+ sample = {"scan": scan_imgs,
+ "labels":scan_lbs,
+ }
+
+ return sample
+
+
+ def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False):
+ """
+ Get a (possibly multi-shot) support set covering several classes at once.
+ Support slices are sampled at evenly spaced relative depths, restricted to
+ slices that contain every requested class.
+ Args:
+ classes: list of classes to segment, each starting from 1
+ scan_idx: a list of the **i-th** training scans (indices into self.pid_curr_load,
+ not patient ids) to be used as support
+ npart: how many chunks are used to split the support
+ use_3_slices: if True, stack the previous/current/next slices as image channels
+ """
+ assert npart % 2 == 1
+ # assert curr_class != 0; assert 0 not in class_idx
+ # assert not self.is_train
+
+ self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
+ # print(f'###### Using {len(scan_idx)} shot evaluation!')
+
+ if npart == 1:
+ pcts = [0.5]
+ else:
+ half_part = 1 / (npart * 2)
+ part_interval = (1.0 - 1.0 / npart) / (npart - 1)
+ pcts = [ half_part + part_interval * ii for ii in range(npart) ]
+
+ # print(f'###### Parts percentage: {pcts} ######')
+
+ out_buffer = [] # [{scanid, img, lb}]
+ for _part in range(npart):
+ concat_buffer = [] # for each part, concatenate images and masks along the batch dimension
+ for scan_order in scan_idx:
+ _scan_id = self.pid_curr_load[ scan_order ]
+ print(f'Using scan {_scan_id} as support!')
+
+ # for _pc in pcts:
+ zlist = []
+ for curr_class in classes:
+ zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices
+ # merge all the lists in zlist and keep only the unique elements
+ # _zlist = sorted(list(set([item for sublist in zlist for item in sublist])))
+ # take only the indices that appear in all of the sublist
+ _zlist = sorted(list(set.intersection(*map(set, zlist))))
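+ # e.g. (illustrative) zlist = [[10, 11, 12], [11, 12, 13]] -> _zlist = [11, 12]:
+ # only slices containing every requested class are kept as support candidates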
+ _zid = _zlist[int(pcts[_part] * len(_zlist))]
+ _glb_idx = self.scan_z_idx[_scan_id][_zid]
+
+ # almost copy-paste __getitem__ but no augmentation
+ curr_dict = self.actual_dataset[_glb_idx]
+ img = curr_dict['img']
+ lb = curr_dict['lb']
+
+ if use_3_slices:
+ prev_image = np.zeros_like(img)
+ if _glb_idx > 1 and not curr_dict["is_start"]:
+ prev_dict = self.actual_dataset[_glb_idx - 1]
+ assert prev_dict["scan_id"] == curr_dict["scan_id"]
+ assert prev_dict["z_id"] == curr_dict["z_id"] - 1
+ prev_image = prev_dict["img"]
+
+ next_image = np.zeros_like(img)
+ if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
+ next_dict = self.actual_dataset[_glb_idx + 1]
+ assert next_dict["scan_id"] == curr_dict["scan_id"]
+ assert next_dict["z_id"] == curr_dict["z_id"] + 1
+ next_image = next_dict["img"]
+
+ img = np.concatenate([prev_image, img, next_image], axis=-1)
+
+ img = np.float32(img)
+ lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
+ # zero all labels that are not in the classes arg
+ mask = np.zeros_like(lb)
+ for cls in classes:
+ mask[lb == cls] = 1
+ lb[~mask.astype(bool)] = 0
+
+ img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
+ lb = torch.from_numpy( lb )
+
+ if self.tile_z_dim:
+ img = img.repeat( [ self.tile_z_dim, 1, 1] )
+ assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
+
+ is_start = curr_dict["is_start"]
+ is_end = curr_dict["is_end"]
+ nframe = np.int32(curr_dict["nframe"])
+ scan_id = curr_dict["scan_id"]
+ z_id = curr_dict["z_id"]
+
+ sample = {"image": img,
+ "label":lb,
+ "is_start": is_start,
+ "inst": None,
+ "scribble": None,
+ "is_end": is_end,
+ "nframe": nframe,
+ "scan_id": scan_id,
+ "z_id": z_id
+ }
+
+ concat_buffer.append(sample)
+ out_buffer.append({
+ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
+ "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
+
+ })
+
+ # do the concat, and add to output_buffer
+
+ # post-processing, including keeping the foreground and suppressing background.
+ support_images = []
+ support_mask = []
+ support_class = []
+ for itm in out_buffer:
+ support_images.append(itm["image"])
+ support_class.append(curr_class)
+ # support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
+ support_mask.append(itm["label"])
+
+ return {'class_ids': [support_class],
+ 'support_images': [support_images], #
+ 'support_mask': [support_mask],
+ 'scan_id': scan_id
+ }
+
+def get_nii_dataset(config, image_size, **kwargs):
+ print(f"Check config: {config}")
+ organ_mapping = {
+ "sabs":{
+ "rk": 2,
+ "lk": 3,
+ "liver": 6,
+ "spleen": 1
+ },
+ "chaost2":{
+ "liver": 1,
+ "rk": 2,
+ "lk": 3,
+ "spleen": 4
+ }}
+
+ transforms = None
+ data_name = config['dataset']
+ if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672':
+ baseset_name = 'SABS'
+ max_label = 13
+ modality="CT"
+ elif data_name == 'C0_Superpix':
+ raise NotImplementedError
+ baseset_name = 'C0'
+ max_label = 3
+ elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672':
+ baseset_name = 'CHAOST2'
+ max_label = 4
+ modality="MR"
+ elif 'lits' in data_name.lower():
+ baseset_name = 'LITS17'
+ max_label = 4
+ else:
+ raise ValueError(f'Dataset: {data_name} not found')
+
+ # norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics
+ # norm_func = None
+
+ test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]]
+ base_dir = config['path'][data_name]['data_dir']
+ testdataset = ManualAnnoDataset(which_dataset=baseset_name,
+ base_dir=base_dir,
+ idx_split = config['eval_fold'],
+ mode = 'val',
+ scan_per_load = 1,
+ transforms=transforms,
+ min_fg=1,
+ nsup = config["task"]["n_shots"],
+ fix_length=None,
+ image_size=image_size,
+ # extern_normalize_func=norm_func
+ **kwargs)
+
+ testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"])
+ testdataset.set_curr_cls(test_label)
+
+ traindataset = None # TODO make this the support set later
+
+ return traindataset, testdataset
\ No newline at end of file
diff --git a/dataloaders/PolypDataset.py b/dataloaders/PolypDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb827d1618887495bf6bbbed29a202c0a085b3b
--- /dev/null
+++ b/dataloaders/PolypDataset.py
@@ -0,0 +1,548 @@
+"""
+Copied from https://github.com/talshaharabany/AutoSAM
+"""
+
+import os
+from PIL import Image
+import torch.utils.data as data
+import torchvision.transforms as transforms
+import numpy as np
+import random
+import torch
+from dataloaders.PolypTransforms import get_polyp_transform
+import cv2
+KVASIR = "Kvasir"
+CLINIC_DB = "CVC-ClinicDB"
+COLON_DB = "CVC-ColonDB"
+ETIS_DB = "ETIS-LaribPolypDB"
+CVC300 = "CVC-300"
+
+DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
+EXCLUDE_DS = (CVC300, )
+
+
+def create_suppport_set_for_polyps(n_support=10):
+ """
+ Create a text file listing n_support image/mask pairs from the training set to use as the support set.
+ """
+ root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"
+ supp_images = []
+ supp_masks = []
+
+ image_dir = os.path.join(root_dir, "images")
+ mask_dir = os.path.join(root_dir, "masks")
+ # randomly sample n_support images and masks
+ image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(
+ image_dir) if f.endswith('.jpg') or f.endswith('.png')])
+ mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(
+ mask_dir) if f.endswith('.png')])
+
+ while len(supp_images) < n_support:
+ index = random.randint(0, len(image_paths) - 1)
+ # check that the index is not already in the support set
+ if image_paths[index] in supp_images:
+ continue
+ supp_images.append(image_paths[index])
+ supp_masks.append(mask_paths[index])
+
+ with open(os.path.join(root_dir, "support.txt"), 'w') as file:
+ for image_path, mask_path in zip(supp_images, supp_masks):
+ file.write(f"{image_path} {mask_path}\n")
+
+def create_train_val_test_split_for_polyps():
+ root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"
+ # for each subdir in root_dir, create a split file
+ num_train_images_per_dataset = {
+ "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
+
+ num_test_images_per_dataset = {
+ "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}
+
+ for subdir in os.listdir(root_dir):
+ subdir_path = os.path.join(root_dir, subdir)
+ if os.path.isdir(subdir_path):
+ split_file = os.path.join(subdir_path, "split.txt")
+ image_dir = os.path.join(subdir_path, "images")
+ create_train_val_test_split(
+ image_dir, split_file, train_number=num_train_images_per_dataset[subdir], test_number=num_test_images_per_dataset[subdir])
+
+
+def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
+ """
+ Create a train, val, test split file for a dataset
+ root: root directory of the dataset images
+ split_file: path of the split file to create
+ train_number: number of images in the train set
+ test_number: number of images in the test set
+ (the remaining images go into the val set)
+ """
+ # Get all files in root directory
+ files = os.listdir(root)
+ # Filter out non-image files, remove suffix
+ files = [f.split('.')[0]
+ for f in files if f.endswith('.jpg') or f.endswith('.png')]
+ # Shuffle files
+ random.shuffle(files)
+
+ # Calculate number of files for each split
+ num_files = len(files)
+ num_train = train_number
+ num_test = test_number
+ num_val = num_files - num_train - num_test
+ print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
+ # Create splits
+ train = files[:num_train]
+ val = files[num_train:num_train + num_val]
+ test = files[num_train + num_val:]
+
+ # Write splits to file
+ with open(split_file, 'w') as file:
+ file.write("train\n")
+ for f in train:
+ file.write(f + "\n")
+ file.write("val\n")
+ for f in val:
+ file.write(f + "\n")
+ file.write("test\n")
+ for f in test:
+ file.write(f + "\n")
+
+
+class PolypDataset(data.Dataset):
+ """
+ dataloader for polyp segmentation tasks
+ """
+
+ def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
+ self.trainsize = trainsize
+ self.augmentations = augmentations
+ self.datasets = datasets
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ self.image_size = image_size
+ if image_root is not None and gt_root is not None:
+ self.images = [
+ os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
+ self.gts = [
+ os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png')]
+ # also look in subdirectories
+ for subdir in os.listdir(image_root):
+ # if not dir, continue
+ if not os.path.isdir(os.path.join(image_root, subdir)):
+ continue
+ subdir_image_root = os.path.join(image_root, subdir)
+ subdir_gt_root = os.path.join(gt_root, subdir)
+ self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
+ subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
+ self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
+ subdir_gt_root) if f.endswith('.png')])
+
+ else:
+ self.images, self.gts = self.get_image_gt_pairs(
+ root, split="train" if train else "test", datasets=self.datasets)
+ self.images = sorted(self.images)
+ self.gts = sorted(self.gts)
+ if 'VPS' not in root:
+ self.filter_files_and_get_ds_mean_and_std()
+ if ds_mean is not None and ds_std is not None:
+ self.mean, self.std = ds_mean, ds_std
+ self.size = len(self.images)
+ self.train = train
+ self.sam_trans = sam_trans
+ if self.sam_trans is not None:
+ # sam trans takes care of norm
+ self.mean, self.std = 0 , 1
+
+ def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
+ """
+ for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file
+ dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
+ split: train, val, or test
+ """
+ image_paths = []
+ gt_paths = []
+ for folder in os.listdir(dir_root):
+ if folder not in datasets:
+ continue
+ split_file = os.path.join(dir_root, folder, "split.txt")
+ if os.path.isfile(split_file):
+ image_root = os.path.join(dir_root, folder, "images")
+ gt_root = os.path.join(dir_root, folder, "masks")
+ image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file(
+ image_root, gt_root, split_file, split=split)
+ image_paths.extend(image_paths_tmp)
+ gt_paths.extend(gt_paths_tmp)
+ else:
+ print(
+ f"No split.txt file found in {os.path.join(dir_root, folder)}")
+
+ return image_paths, gt_paths
+
+ def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
+ """
+ image_root: root directory of images
+ gt_root: root directory of ground truth
+ text_file: text file containing the train, val, test split, with plain section
+ names (as written by create_train_val_test_split) in the following format:
+ train
+ image1
+ image2
+ ...
+ val
+ image1
+ image2
+ ...
+ test
+ image1
+ image2
+ ...
+
+ split: train, val, or test
+ """
+ # Initialize a dictionary to hold file names for each split
+ splits = {"train": [], "val": [], "test": []}
+ current_split = None
+
+ # Read the text file and categorize file names under each split
+ with open(text_file, 'r') as file:
+ for line in file:
+ line = line.strip()
+ if line in splits:
+ current_split = line
+ elif line and current_split:
+ splits[current_split].append(line)
+
+ # Get the file names for the requested split
+ file_names = splits[split]
+
+ # Create image-ground truth pairs
+ image_paths = []
+ gt_paths = []
+ for name in file_names:
+ image_path = os.path.join(image_root, name + '.png')
+ gt_path = os.path.join(gt_root, name + '.png')
+ image_paths.append(image_path)
+ gt_paths.append(gt_path)
+
+ return image_paths, gt_paths
+
+ def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
+ support_images = []
+ support_labels = []
+ # get all images and masks
+ support_image_paths = sorted([os.path.join(support_image_dir, f) for f in os.listdir(
+ support_image_dir) if f.endswith('.jpg') or f.endswith('.png')])
+ support_mask_paths = sorted([os.path.join(support_mask_dir, f) for f in os.listdir(
+ support_mask_dir) if f.endswith('.png')])
+ # sample n_support images and masks
+ for i in range(n_support):
+ index = random.randint(0, len(support_image_paths) - 1)
+ support_img = self.cv2_loader(
+ support_image_paths[index], is_mask=False)
+ support_mask = self.cv2_loader(
+ support_mask_paths[index], is_mask=True)
+ support_images.append(support_img)
+ support_labels.append(support_mask)
+
+ if self.augmentations:
+ support_images = [self.augmentations(
+ img, mask)[0] for img, mask in zip(support_images, support_labels)]
+ support_labels = [self.augmentations(
+ img, mask)[1] for img, mask in zip(support_images, support_labels)]
+
+ support_images = [(support_image - self.mean) / self.std if support_image.max() == 255 and support_image.min() == 0 else support_image for support_image in support_images]
+
+ if self.sam_trans is not None:
+ support_images = [self.sam_trans.preprocess(
+ img).squeeze(0) for img in support_images]
+ support_labels = [self.sam_trans.preprocess(
+ mask) for mask in support_labels]
+ else:
+ image_size = self.image_size
+ support_images = [torch.nn.functional.interpolate(img.unsqueeze(
+ 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images]
+ support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
+ 0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels]
+
+ return torch.stack(support_images), torch.stack(support_labels)
+
+ def get_support_from_text_file(self, text_file, n_support=1):
+ """
+ Each row in the file has two paths separated by a space: the first is the image path and the second is the mask path.
+ """
+ support_images = []
+ support_labels = []
+ with open(text_file, 'r') as file:
+ for line in file:
+ image_path, mask_path = line.strip().split()
+ support_images.append(image_path)
+ support_labels.append(mask_path)
+
+ # indices = random.choices(range(len(support_images)), k=n_support)
+ if n_support > len(support_images):
+ raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")
+
+ n_support_images = support_images[:n_support]
+ n_support_labels = support_labels[:n_support]
+
+ return n_support_images, n_support_labels
+
+ def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
+ """
+ Get a support set from the specified directories, from a text file, or from the dataset itself.
+ """
+ if support_image_dir is not None and support_mask_dir:
+ return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)
+ elif text_file is not None:
+ support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support)
+ else:
+ # randomly sample n_support images and masks from the dataset
+ indices = random.choices(range(self.size), k=n_support)
+ # indices = list(range(n_support))
+ print(f"support indices:{indices}")
+ support_image_paths = [self.images[index] for index in indices]
+ support_gt_paths = [self.gts[index] for index in indices]
+
+ support_images = []
+ support_gts = []
+
+ for image_path, gt_path in zip(support_image_paths, support_gt_paths):
+ support_img = self.cv2_loader(image_path, is_mask=False)
+ support_mask = self.cv2_loader(gt_path, is_mask=True)
+ out = self.process_image_gt(support_img, support_mask)
+ support_images.append(out['image'].unsqueeze(0))
+ support_gts.append(out['label'].unsqueeze(0))
+ if len(support_images) >= n_support:
+ break
+ return support_images, support_gts, out['case']
+ # return torch.stack(support_images), torch.stack(support_gts), out['case']
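+ # Minimal usage sketch (illustrative, assuming `ds` is a PolypDataset built via
+ # get_polyp_dataset, i.e. sam_trans=None):
+ #   support_images, support_gts, case = ds.get_support(n_support=1)
+ #   support_images[0].shape  # -> (1, 3, H, W) after process_image_gt + unsqueeze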
+
+ def process_image_gt(self, image, gt, dataset=""):
+ """
+ image and gt are expected to be output from self.cv2_loader
+ """
+ original_size = tuple(image.shape[:2]) # (H, W) of the raw image; cv2_loader returns an HWC array
+ if self.augmentations:
+ image, mask = self.augmentations(image, gt)
+
+ if self.sam_trans:
+ image, mask = self.sam_trans.apply_image_torch(
+ image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
+ elif image.max() <= 255 and image.min() >= 0:
+ image = (image - self.mean) / self.std
+ mask[mask > 0.5] = 1
+ mask[mask <= 0.5] = 0
+ # image_size = tuple(img.shape[-2:])
+
+ image_size = self.image_size
+ if self.sam_trans is None:
+ image = torch.nn.functional.interpolate(image.unsqueeze(
+ 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
+ 0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
+ # img = (img - img.min()) / (img.max() - img.min()) # TODO uncomment this if results get worse
+
+ return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image,
+ 'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask,
+ 'original_size': torch.Tensor(original_size),
+ 'image_size': torch.Tensor(image_size),
+ 'case': dataset} # case to be compatible with polyp video dataset
+
+ def get_dataset_name_from_path(self, path):
+ for dataset in self.datasets:
+ if dataset in path:
+ return dataset
+ return ""
+
+ def __getitem__(self, index):
+ image = self.cv2_loader(self.images[index], is_mask=False)
+ gt = self.cv2_loader(self.gts[index], is_mask=True)
+ dataset = self.get_dataset_name_from_path(self.images[index])
+ return self.process_image_gt(image, gt, dataset)
+
+ def filter_files_and_get_ds_mean_and_std(self):
+ assert len(self.images) == len(self.gts)
+ images = []
+ gts = []
+ ds_mean = 0
+ ds_std = 0
+ for img_path, gt_path in zip(self.images, self.gts):
+ if any([ex_ds in img_path for ex_ds in EXCLUDE_DS]):
+ continue
+ img = Image.open(img_path)
+ gt = Image.open(gt_path)
+ if img.size == gt.size:
+ images.append(img_path)
+ gts.append(gt_path)
+ ds_mean += np.array(img).mean()
+ ds_std += np.array(img).std()
+ self.images = images
+ self.gts = gts
+ self.mean = ds_mean / len(self.images)
+ self.std = ds_std / len(self.images)
+
+ def rgb_loader(self, path):
+ with open(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
+ def binary_loader(self, path):
+ # with open(path, 'rb') as f:
+ # img = Image.open(f)
+ # return img.convert('1')
+ img = cv2.imread(path, 0)
+ return img
+
+ def cv2_loader(self, path, is_mask):
+ if is_mask:
+ img = cv2.imread(path, 0)
+ img[img > 0] = 1
+ else:
+ img = cv2.cvtColor(cv2.imread(
+ path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+ return img
+
+ def resize(self, img, gt):
+ assert img.size == gt.size
+ w, h = img.size
+ if h < self.trainsize or w < self.trainsize:
+ h = max(h, self.trainsize)
+ w = max(w, self.trainsize)
+ return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
+ else:
+ return img, gt
+
+ def __len__(self):
+ # return 32
+ return self.size
+
+
+class SuperpixPolypDataset(PolypDataset):
+ def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
+ self.trainsize = trainsize
+ self.augmentations = augmentations
+ self.datasets = datasets
+ self.image_size = image_size
+ # print(self.augmentations)
+ if image_root is not None and gt_root is not None:
+ self.images = [
+ os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
+ self.gts = [
+ os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png') and 'superpix' in f]
+ # also look in subdirectories
+ for subdir in os.listdir(image_root):
+ # if not dir, continue
+ if not os.path.isdir(os.path.join(image_root, subdir)):
+ continue
+ subdir_image_root = os.path.join(image_root, subdir)
+ subdir_gt_root = os.path.join(gt_root, subdir)
+ self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
+ subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
+ self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
+ subdir_gt_root) if f.endswith('.png')])
+
+ else:
+ self.images, self.gts = self.get_image_gt_pairs(
+ root, split="train" if train else "test", datasets=self.datasets)
+ self.images = sorted(self.images)
+ self.gts = sorted(self.gts)
+ if 'VPS' not in root:
+ self.filter_files_and_get_ds_mean_and_std()
+ if ds_mean is not None and ds_std is not None:
+ self.mean, self.std = ds_mean, ds_std
+ self.size = len(self.images)
+ self.train = train
+ self.sam_trans = sam_trans
+ if self.sam_trans is not None:
+ # sam trans takes care of norm
+ self.mean, self.std = 0 , 1
+
+
+ def __getitem__(self, index):
+ image = self.cv2_loader(self.images[index], is_mask=False)
+ gt = self.cv2_loader(self.gts[index], is_mask=False)
+ gt = gt[:, :, 0]
+ fgpath = os.path.basename(self.gts[index]).split('.png')[0].split('superpix-MIDDLE_')
+ fgpath = os.path.join(os.path.dirname(self.gts[index]), 'fgmask_' + fgpath[1] + '.png')
+ fg = self.cv2_loader(fgpath, is_mask=True)
+ dataset = self.get_dataset_name_from_path(self.images[index])
+
+ # zero-out superpixels that fall outside the foreground mask,
+ # then randomly choose one superpixel from the gt
+ gt[fg == 0] = 0
+ sp_id = random.choice(np.unique(gt)[1:])
+ sp = (gt == sp_id).astype(np.uint8)
+
+
+ out = self.process_image_gt(image, gt, dataset)
+ support_image, support_sp, dataset = out["image"], out["label"], out["case"]
+
+ out = self.process_image_gt(image, sp, dataset)
+ query_image, query_sp, dataset = out["image"], out["label"], out["case"]
+
+ # TODO tile the masks to have 3 channels?
+
+ support_bg_mask = 1 - support_sp
+ support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}
+
+ batch = {"support_images" : [[support_image]],
+ "support_mask" : [[support_masks]],
+ "query_images" : [query_image],
+ "query_labels" : [query_sp],
+ "scan_id" : [dataset]
+ }
+
+ return batch
+
+
+def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None):
+ transform_train, transform_test = get_polyp_transform()
+ image_root = './data/PolypDataset/TrainDataset/images/'
+ gt_root = './data/PolypDataset/TrainDataset/superpixels/'
+ ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
+ augmentations=transform_train,
+ sam_trans=sam_trans,
+ image_size=image_size)
+
+ return ds_train
+
+def get_polyp_dataset(image_size, sam_trans=None):
+ transform_train, transform_test = get_polyp_transform()
+ image_root = './data/PolypDataset/TrainDataset/images/'
+ gt_root = './data/PolypDataset/TrainDataset/masks/'
+ ds_train = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
+ augmentations=transform_test, sam_trans=sam_trans, train=True, image_size=image_size)
+ image_root = './data/PolypDataset/TestDataset/test/images/'
+ gt_root = './data/PolypDataset/TestDataset/test/masks/'
+ ds_test = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, train=False,
+ augmentations=transform_test, sam_trans=sam_trans, image_size=image_size)
+ return ds_train, ds_test
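+# Minimal usage sketch (illustrative; assumes the ./data/PolypDataset layout above):
+#   ds_train, ds_test = get_polyp_dataset(image_size=(1024, 1024))
+#   loader = torch.utils.data.DataLoader(ds_test, batch_size=1, shuffle=False)
+#   batch = next(iter(loader))  # dict with 'image', 'label', 'original_size', 'image_size', 'case'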
+
+
+def get_tests_polyp_dataset(sam_trans):
+ transform_train, transform_test = get_polyp_transform()
+
+ image_root = './data/polyp/TestDataset/Kvasir/images/'
+ gt_root = './data/polyp/TestDataset/Kvasir/masks/'
+ ds_Kvasir = PolypDataset(
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
+
+ image_root = './data/polyp/TestDataset/CVC-ClinicDB/images/'
+ gt_root = './data/polyp/TestDataset/CVC-ClinicDB/masks/'
+ ds_ClinicDB = PolypDataset(
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
+
+ image_root = './data/polyp/TestDataset/CVC-ColonDB/images/'
+ gt_root = './data/polyp/TestDataset/CVC-ColonDB/masks/'
+ ds_ColonDB = PolypDataset(
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
+
+ image_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/images/'
+ gt_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/masks/'
+ ds_ETIS = PolypDataset(
+ image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
+
+ return ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS
+
+
+if __name__ == '__main__':
+ # create_train_val_test_split_for_polyps()
+ create_suppport_set_for_polyps()
diff --git a/dataloaders/PolypTransforms.py b/dataloaders/PolypTransforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3981d098f0f296c3dc1d004aadcb017c3464db
--- /dev/null
+++ b/dataloaders/PolypTransforms.py
@@ -0,0 +1,626 @@
+from __future__ import division
+import torch
+import math
+import sys
+import random
+from PIL import Image
+
+try:
+ import accimage
+except ImportError:
+ accimage = None
+import numpy as np
+import numbers
+import types
+import collections
+import warnings
+
+from torchvision.transforms import functional as F
+
+if sys.version_info < (3, 3):
+ Sequence = collections.Sequence
+ Iterable = collections.Iterable
+else:
+ Sequence = collections.abc.Sequence
+ Iterable = collections.abc.Iterable
+
+__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
+ "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
+ "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop",
+ "ColorJitter", "RandomRotation", "RandomAffine",
+ "RandomPerspective"]
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: 'PIL.Image.NEAREST',
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
+ Image.HAMMING: 'PIL.Image.HAMMING',
+ Image.BOX: 'PIL.Image.BOX',
+}
+
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img, mask):
+ for t in self.transforms:
+ img, mask = t(img, mask)
+ return img, mask
+
+
+class ToTensor(object):
+ def __call__(self, img, mask):
+ # return F.to_tensor(img), F.to_tensor(mask)
+ img = np.array(img)
+ img = torch.from_numpy(img).permute(2, 0, 1).float() # TODO add division by 255 to match torch.ToTensor()?
+ mask = torch.from_numpy(np.array(mask)).float()
+ return img, mask
+
+
+class ToPILImage(object):
+ def __init__(self, mode=None):
+ self.mode = mode
+
+ def __call__(self, img, mask):
+ return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode)
+
+
+class Normalize(object):
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, img, mask):
+ return F.normalize(img, self.mean, self.std, self.inplace), mask
+
+
+class Resize(object):
+ def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+ self.size = size
+ self.interpolation = interpolation
+ self.do_mask = do_mask
+
+ def __call__(self, img, mask):
+ if self.do_mask:
+ return F.resize(img, self.size, Image.BICUBIC), F.resize(mask, self.size, Image.BICUBIC)
+ else:
+ return F.resize(img, self.size, Image.BICUBIC), mask
+
+
+class CenterCrop(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img, mask):
+ return F.center_crop(img, self.size), F.center_crop(mask, self.size)
+
+
+class Pad(object):
+ def __init__(self, padding, fill=0, padding_mode='constant'):
+ assert isinstance(padding, (numbers.Number, tuple))
+ assert isinstance(fill, (numbers.Number, str, tuple))
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+ if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
+ "{} element tuple".format(len(padding)))
+
+ self.padding = padding
+ self.fill = fill
+ self.padding_mode = padding_mode
+
+ def __call__(self, img, mask):
+ return F.pad(img, self.padding, self.fill, self.padding_mode), \
+ F.pad(mask, self.padding, self.fill, self.padding_mode)
+
+
+class Lambda(object):
+ def __init__(self, lambd):
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
+ self.lambd = lambd
+
+ def __call__(self, img, mask):
+ return self.lambd(img), self.lambd(mask)
+
+
+class Lambda_image(object):
+ def __init__(self, lambd):
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
+ self.lambd = lambd
+
+ def __call__(self, img, mask):
+ return self.lambd(img), mask
+
+
+class RandomTransforms(object):
+ def __init__(self, transforms):
+ assert isinstance(transforms, (list, tuple))
+ self.transforms = transforms
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError()
+
+
+class RandomApply(RandomTransforms):
+ def __init__(self, transforms, p=0.5):
+ super(RandomApply, self).__init__(transforms)
+ self.p = p
+
+ def __call__(self, img, mask):
+ if self.p < random.random():
+ return img, mask
+ for t in self.transforms:
+ img, mask = t(img, mask)
+ return img, mask
+
+
+class RandomOrder(RandomTransforms):
+ def __call__(self, img, mask):
+ order = list(range(len(self.transforms)))
+ random.shuffle(order)
+ for i in order:
+ img, mask = self.transforms[i](img, mask)
+ return img, mask
+
+
+class RandomChoice(RandomTransforms):
+ def __call__(self, img, mask):
+ t = random.choice(self.transforms)
+ return t(img, mask)
+
+
+class RandomCrop(object):
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+ self.pad_if_needed = pad_if_needed
+ self.fill = fill
+ self.padding_mode = padding_mode
+
+ @staticmethod
+ def get_params(img, output_size):
+ w, h = img.size
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+
+ def __call__(self, img, mask):
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ # pad the width if needed
+ if self.pad_if_needed and img.size[0] < self.size[1]:
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and img.size[1] < self.size[0]:
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+
+ return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, mask):
+ if random.random() < self.p:
+ return F.hflip(img), F.hflip(mask)
+ return img, mask
+
+
+class RandomVerticalFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, mask):
+ if random.random() < self.p:
+ return F.vflip(img), F.vflip(mask)
+ return img, mask
+
+
+class RandomPerspective(object):
+ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
+ self.p = p
+ self.interpolation = interpolation
+ self.distortion_scale = distortion_scale
+
+ def __call__(self, img, mask):
+ if not F._is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ if random.random() < self.p:
+ width, height = img.size
+ startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
+ return F.perspective(img, startpoints, endpoints, self.interpolation), \
+ F.perspective(mask, startpoints, endpoints, Image.NEAREST)
+ return img, mask
+
+ @staticmethod
+ def get_params(width, height, distortion_scale):
+ half_height = int(height / 2)
+ half_width = int(width / 2)
+ topleft = (random.randint(0, int(distortion_scale * half_width)),
+ random.randint(0, int(distortion_scale * half_height)))
+ topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+ random.randint(0, int(distortion_scale * half_height)))
+ botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+ botleft = (random.randint(0, int(distortion_scale * half_width)),
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+ endpoints = [topleft, topright, botright, botleft]
+ return startpoints, endpoints
+
+
+class RandomResizedCrop(object):
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ warnings.warn("range should be of kind (min, max)")
+
+ self.interpolation = interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ area = img.size[0] * img.size[1]
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if (in_ratio < min(ratio)):
+ w = img.size[0]
+ h = w / min(ratio)
+ elif (in_ratio > max(ratio)):
+ h = img.size[1]
+ w = h * max(ratio)
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+ def __call__(self, img, mask):
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
+ F.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST)
+
+
+class FiveCrop(object):
+ def __init__(self, size):
+ self.size = size
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+ self.size = size
+
+ def __call__(self, img, mask):
+ return F.five_crop(img, self.size), F.five_crop(mask, self.size)
+
+
+class TenCrop(object):
+ def __init__(self, size, vertical_flip=False):
+ self.size = size
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+ self.size = size
+ self.vertical_flip = vertical_flip
+
+ def __call__(self, img, mask):
+ return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
+
+
+class ColorJitter(object):
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ self.brightness = self._check_input(brightness, 'brightness')
+ self.contrast = self._check_input(contrast, 'contrast')
+ self.saturation = self._check_input(saturation, 'saturation')
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
+ clip_first_on_zero=False)
+
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
+ value = [center - value, center + value]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError("{} values should be between {}".format(name, bound))
+ else:
+ raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ return value
+
+ @staticmethod
+ def get_params(brightness, contrast, saturation, hue):
+ transforms = []
+
+ if brightness is not None:
+ brightness_factor = random.uniform(brightness[0], brightness[1])
+ transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
+
+ if contrast is not None:
+ contrast_factor = random.uniform(contrast[0], contrast[1])
+ transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
+
+ if saturation is not None:
+ saturation_factor = random.uniform(saturation[0], saturation[1])
+ transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
+
+ if hue is not None:
+ hue_factor = random.uniform(hue[0], hue[1])
+ transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
+
+ random.shuffle(transforms)
+ transform = Compose(transforms)
+
+ return transform
+
+ def __call__(self, img, mask):
+ transform = self.get_params(self.brightness, self.contrast,
+ self.saturation, self.hue)
+ return transform(img, mask)
+
+
+class RandomRotation(object):
+ def __init__(self, degrees, resample=False, expand=False, center=None):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError("If degrees is a single number, it must be positive.")
+ self.degrees = (-degrees, degrees)
+ else:
+ if len(degrees) != 2:
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
+ self.degrees = degrees
+
+ self.resample = resample
+ self.expand = expand
+ self.center = center
+
+ @staticmethod
+ def get_params(degrees):
+ angle = random.uniform(degrees[0], degrees[1])
+
+ return angle
+
+ def __call__(self, img, mask):
+ angle = self.get_params(self.degrees)
+
+ return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \
+ F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
+
+
+class RandomAffine(object):
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError("If degrees is a single number, it must be positive.")
+ self.degrees = (-degrees, degrees)
+ else:
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+ "degrees should be a list or tuple and it must be of length 2."
+ self.degrees = degrees
+
+ if translate is not None:
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+ "translate should be a list or tuple and it must be of length 2."
+ for t in translate:
+ if not (0.0 <= t <= 1.0):
+ raise ValueError("translation values should be between 0 and 1")
+ self.translate = translate
+
+ if scale is not None:
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+ "scale should be a list or tuple and it must be of length 2."
+ for s in scale:
+ if s <= 0:
+ raise ValueError("scale values should be positive")
+ self.scale = scale
+
+ if shear is not None:
+ if isinstance(shear, numbers.Number):
+ if shear < 0:
+ raise ValueError("If shear is a single number, it must be positive.")
+ self.shear = (-shear, shear)
+ else:
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
+ "shear should be a list or tuple and it must be of length 2."
+ self.shear = shear
+ else:
+ self.shear = shear
+
+ self.resample = resample
+ self.fillcolor = fillcolor
+
+ @staticmethod
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
+ angle = random.uniform(degrees[0], degrees[1])
+ if translate is not None:
+ max_dx = translate[0] * img_size[0]
+ max_dy = translate[1] * img_size[1]
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
+ np.round(random.uniform(-max_dy, max_dy)))
+ else:
+ translations = (0, 0)
+
+ if scale_ranges is not None:
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
+ else:
+ scale = 1.0
+
+ if shears is not None:
+ shear = random.uniform(shears[0], shears[1])
+ else:
+ shear = 0.0
+
+ return angle, translations, scale, shear
+
+ def __call__(self, img, mask):
+ ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
+ return F.affine(img, *ret, interpolation=Image.BILINEAR, fill=self.fillcolor), \
+ F.affine(mask, *ret, interpolation=Image.NEAREST, fill=self.fillcolor)
+
+
+
+def get_cub_transform():
+ transform_train = Compose([
+ ToPILImage(),
+ Resize((256, 256)),
+ RandomHorizontalFlip(),
+ RandomAffine(22, scale=(0.75, 1.25)),
+ ToTensor(),
+ Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
+ ])
+ transform_test = Compose([
+ ToPILImage(),
+ Resize((256, 256)),
+ ToTensor(),
+ Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
+ ])
+ return transform_train, transform_test
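+
+# Usage sketch (hypothetical toy arrays): every transform in these Compose pipelines takes and
+# returns the (image, mask) pair, so geometric augmentations stay aligned across both.
+#   img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
+#   mask = np.random.randint(0, 2, (256, 256), dtype=np.uint8)
+#   transform_train, transform_test = get_cub_transform()
+#   img_t, mask_t = transform_train(img, mask)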
+
+
+def get_glas_transform():
+ transform_train = Compose([
+ ToPILImage(),
+ # Resize((256, 256)),
+ ColorJitter(brightness=0.2,
+ contrast=0.2,
+ saturation=0.2,
+ hue=0.1),
+ RandomHorizontalFlip(),
+ RandomAffine(5, scale=(0.75, 1.25)),
+ ToTensor(),
+ # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
+ ])
+ transform_test = Compose([
+ ToPILImage(),
+ # Resize((256, 256)),
+ ToTensor(),
+ # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
+ ])
+ return transform_train, transform_test
+
+# def get_glas_transform():
+# transform_train = Compose([
+# ToPILImage(),
+# Resize((256, 256)),
+# ColorJitter(brightness=0.2,
+# contrast=0.2,
+# saturation=0.2,
+# hue=0.1),
+# RandomHorizontalFlip(),
+# RandomAffine(5, scale=(0.75, 1.25)),
+# ToTensor(),
+# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
+# ])
+# transform_test = Compose([
+# ToPILImage(),
+# Resize((256, 256)),
+# ToTensor(),
+# Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
+# ])
+# return transform_train, transform_test
+
+
+def get_monu_transform(args):
+ Idim = int(args['Idim'])
+ transform_train = Compose([
+ ToPILImage(),
+ # Resize((Idim, Idim)),
+ ColorJitter(brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1),
+ RandomHorizontalFlip(),
+ RandomAffine(int(args['rotate']), scale=(float(args['scale1']), float(args['scale2']))),
+ ToTensor(),
+ # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
+ ])
+ transform_test = Compose([
+ ToPILImage(),
+ # Resize((Idim, Idim)),
+ ToTensor(),
+ # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
+ ])
+ return transform_train, transform_test
+
+
+def get_polyp_transform():
+ transform_train = Compose([
+ # Resize((352, 352)),
+ ToPILImage(),
+ ColorJitter(brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1),
+ RandomVerticalFlip(),
+ RandomHorizontalFlip(),
+ RandomAffine(90, scale=(0.75, 1.25)),
+ ToTensor(),
+ # Normalize([105.61, 63.69, 45.67],
+ # [83.08, 55.86, 42.59])
+ ])
+ transform_test = Compose([
+ # Resize((352, 352)),
+ ToPILImage(),
+ ToTensor(),
+ # Normalize([105.61, 63.69, 45.67],
+ # [83.08, 55.86, 42.59])
+ ])
+ return transform_train, transform_test
+
+
+def get_polyp_support_train_transform():
+ transform_train = Compose([
+ ColorJitter(brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1),
+ RandomVerticalFlip(),
+ RandomHorizontalFlip(),
+ RandomAffine(90, scale=(0.75, 1.25)),
+ ])
+
+ return transform_train
\ No newline at end of file
diff --git a/dataloaders/SimpleDataset.py b/dataloaders/SimpleDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12092b638bec466fbf40770b3ca46d8707fd54f
--- /dev/null
+++ b/dataloaders/SimpleDataset.py
@@ -0,0 +1,61 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+"""
+A simple dataset that takes lists of images and masks, together with a transform function that
+should receive both the image and the mask.
+`loops` controls how many times the dataset is iterated per epoch.
+"""
+
+class SimpleDataset(torch.utils.data.Dataset):
+ def __init__(self, image_list, mask_list, transform=None, norm_func=None, loops=10, modality="", debug=False, image_size=None):
+ self.image_list = image_list
+ if image_size is not None:
+ if len(image_size) == 1:
+ image_size = (image_size[0], image_size[0])  # expand a single value to (h, w)
+ self.image_size = image_size
+ else:
+ self.image_size = image_list[0].shape[-2:]
+ self.mask_list = mask_list
+ self.transform = transform
+ self.norm_func = norm_func
+ self.loops = loops
+ self.modality = modality
+ self.debug = debug
+
+ def __len__(self):
+ return len(self.image_list) * self.loops
+
+ def __getitem__(self, idx):
+ idx = idx % (len(self.image_list))
+ image = self.image_list[idx].numpy()
+ mask = self.mask_list[idx].to(dtype=torch.uint8).numpy()
+ if self.modality == "CT":
+ image = image.astype(np.uint8)
+ if self.transform:
+ image, mask = self.transform(image, mask)
+ else:
+ # mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
+ if self.transform:
+ image, mask = self.transform(image, mask)
+
+ if self.norm_func:
+ image = self.norm_func(image)
+
+ mask[mask != 0] = 1
+
+ if self.image_size != image.shape[-2:]:
+ image = torch.nn.functional.interpolate(torch.tensor(image).unsqueeze(0), self.image_size, mode='bilinear').squeeze(0)
+ mask = torch.nn.functional.interpolate(torch.tensor(mask).unsqueeze(0).unsqueeze(0), self.image_size, mode='nearest').squeeze(0).squeeze(0)
+
+ # plot image and mask
+ if self.debug:
+ fig = plt.figure()
+ plt.imshow((image[0]- image.min()) / (image.max() - image.min()))
+ plt.imshow(mask, alpha=0.5)
+ plt.savefig("debug/support_image_mask.png")
+ plt.close(fig)
+
+ image_size = torch.tensor(tuple(image.shape[-2:]))
+ return image, mask
\ No newline at end of file
diff --git a/dataloaders/__init__.py b/dataloaders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dataloaders/augutils.py b/dataloaders/augutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7120afc3fbc23b2a1427b6f71e59be2e7f4cad4
--- /dev/null
+++ b/dataloaders/augutils.py
@@ -0,0 +1,224 @@
+'''
+Utilities for augmentation. Partly credit to Dr. Jo Schlemper
+'''
+from os.path import join
+
+import torch
+import numpy as np
+import torchvision.transforms as deftfx
+import dataloaders.image_transforms as myit
+import copy
+from util.consts import IMG_SIZE
+import time
+import functools
+
+
+def get_sabs_aug(input_size, use_3d=False):
+ sabs_aug = {
+ # turn flipping off as medical data has fixed orientations
+ 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
+ 'affine': {
+ 'rotate': 5,
+ 'shift': (5, 5),
+ 'shear': 5,
+ 'scale': (0.9, 1.2),
+ },
+ 'elastic': {'alpha': 10, 'sigma': 5},
+ 'patch': input_size,
+ 'reduce_2d': True,
+ '3d': use_3d,
+ 'gamma_range': (0.5, 1.5)
+ }
+ return sabs_aug
+
+
+def get_sabs_augv3(input_size):
+ sabs_augv3 = {
+ 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
+ 'affine': {
+ 'rotate': 30,
+ 'shift': (30, 30),
+ 'shear': 30,
+ 'scale': (0.8, 1.3),
+ },
+ 'elastic': {'alpha': 20, 'sigma': 5},
+ 'patch': input_size,
+ 'reduce_2d': True,
+ 'gamma_range': (0.2, 1.8)
+ }
+ return sabs_augv3
+
+
+def get_aug(which_aug, input_size):
+ if which_aug == 'sabs_aug':
+ return get_sabs_aug(input_size)
+ elif which_aug == 'aug_v3':
+ return get_sabs_augv3(input_size)
+ else:
+ raise NotImplementedError
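+
+# Usage sketch: the downstream builders expect the augmentation dict wrapped under an 'aug' key
+# (the input size here is illustrative only):
+#   aug = {'aug': get_aug('sabs_aug', input_size=(256, 256))}
+#   geometric_tfx = get_geometric_transformer(aug, order=1)
+#   intensity_tfx = get_intensity_transformer(aug)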
+
+# augs = {
+# 'sabs_aug': get_sabs_aug,
+# 'aug_v3': get_sabs_augv3, # more aggressive
+# }
+
+
+def get_geometric_transformer(aug, order=3):
+ """order: interpolation degree. Select order=0 for augmenting segmentation """
+ affine = aug['aug'].get('affine', 0)
+ alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
+ sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
+ flip = aug['aug'].get(
+ 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
+
+ tfx = []
+ if 'flip' in aug['aug']:
+ tfx.append(myit.RandomFlip3D(**flip))
+
+ if 'affine' in aug['aug']:
+ tfx.append(myit.RandomAffine(affine.get('rotate'),
+ affine.get('shift'),
+ affine.get('shear'),
+ affine.get('scale'),
+ affine.get('scale_iso', True),
+ order=order))
+
+ if 'elastic' in aug['aug']:
+ tfx.append(myit.ElasticTransform(alpha, sigma))
+ input_transform = deftfx.Compose(tfx)
+ return input_transform
+
+
+def get_geometric_transformer_3d(aug, order=3):
+ """order: interpolation degree. Select order=0 for augmenting segmentation """
+ affine = aug['aug'].get('affine', 0)
+ alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
+ sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
+ flip = aug['aug'].get(
+ 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
+
+ tfx = []
+ if 'flip' in aug['aug']:
+ tfx.append(myit.RandomFlip3D(**flip))
+
+ if 'affine' in aug['aug']:
+ tfx.append(myit.RandomAffine(affine.get('rotate'),
+ affine.get('shift'),
+ affine.get('shear'),
+ affine.get('scale'),
+ affine.get('scale_iso', True),
+ order=order,
+ use_3d=True))
+
+ if 'elastic' in aug['aug']:
+ tfx.append(myit.ElasticTransform(alpha, sigma))
+ input_transform = deftfx.Compose(tfx)
+ return input_transform
+
+
+def gamma_transform(img, aug):
+ gamma_range = aug['aug']['gamma_range']
+ if isinstance(gamma_range, tuple):
+ gamma = np.random.rand() * \
+ (gamma_range[1] - gamma_range[0]) + gamma_range[0]
+ cmin = img.min()
+ irange = (img.max() - cmin + 1e-5)
+
+ img = img - cmin + 1e-5
+ img = irange * np.power(img * 1.0 / irange, gamma)
+ img = img + cmin
+
+ elif gamma_range == False:
+ pass
+ else:
+ raise ValueError(
+ "Cannot identify gamma transform range {}".format(gamma_range))
+ return img
+
+
+def get_intensity_transformer(aug):
+ """some basic intensity transforms"""
+ return functools.partial(gamma_transform, aug=aug)
+
+
+def transform_with_label(aug):
+ """
+ Doing image geometric transform
+ Proposed image to have the following configurations
+ [H x W x C + CL]
+ Where CL is the number of channels for the label. It is NOT in one-hot form
+ """
+
+ geometric_tfx = get_geometric_transformer(aug)
+ intensity_tfx = get_intensity_transformer(aug)
+
+ def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
+ """
+ Args
+ comp: a numpy array with shape [H x W x C + c_label]
+ c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1)
+ use_onehot: whether to return the label in one-hot form (nclass channels); if False, a compact label map is returned
+
+ """
+ comp = copy.deepcopy(comp)
+ if (use_onehot is True) and (c_label != 1):
+ raise NotImplementedError(
+ "Only allow compact label, also the label can only be 2d")
+ assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label"
+
+ # geometric transform
+ _label = comp[..., c_img]
+ _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
+ # _h_label = np.float32(_label[..., None])
+ comp = np.concatenate([comp[..., :c_img], _h_label], -1)
+ comp = geometric_tfx(comp)
+ # round one_hot labels to 0 or 1
+ t_label_h = comp[..., c_img:]
+ t_label_h = np.rint(t_label_h)
+ assert t_label_h.max() <= 1
+ t_img = comp[..., 0: c_img]
+
+ # intensity transform
+ t_img = intensity_tfx(t_img)
+
+ if use_onehot is True:
+ t_label = t_label_h
+ else:
+ t_label = np.expand_dims(np.argmax(t_label_h, axis=-1), -1)
+ return t_img, t_label
+
+ return transform
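+
+# Usage sketch (toy arrays): stack the image channels and a single compact label channel into
+# `comp`, then call the returned closure. Shapes here are illustrative only.
+#   tr = transform_with_label({'aug': get_aug('sabs_aug', (256, 256))})
+#   img = np.random.rand(256, 256, 3).astype(np.float32)              # H x W x C
+#   lbl = np.random.randint(0, 2, (256, 256, 1)).astype(np.float32)   # H x W x 1 compact label
+#   t_img, t_lbl = tr(np.concatenate([img, lbl], -1), c_label=1, c_img=3, use_onehot=False, nclass=2)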
+
+
+def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
+ """
+ Args
+ scan: a numpy array with shape [D x H x W x C]
+ label: a numpy array with shape [D x H x W x 1]
+ """
+ assert len(scan.shape) == 4, "Input scan must be 4D"
+ if len(label.shape) == 3:
+ label = np.expand_dims(label, -1)
+
+ # geometric transform
+ comp = copy.deepcopy(np.concatenate(
+ [scan, label], -1)) # [D x H x W x C + 1]
+ _label = comp[..., -1]
+ _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
+ comp = np.concatenate([comp[..., :-1], _h_label], -1)
+ # change comp to be H x W x D x C + 1
+ comp = np.transpose(comp, (1, 2, 0, 3))
+ comp = geometric_tfx(comp)
+ t_label_h = comp[..., 1:]
+ t_label_h = np.rint(t_label_h)
+ assert t_label_h.max() <= 1
+ t_img = comp[..., 0:1]
+
+ # intensity transform
+ t_img = intensity_tfx(t_img)
+ return t_img, t_label_h
+
+
+def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
+ return transform(scan, label, nclass, geometric_tfx, intensity_tfx)
+
diff --git a/dataloaders/common.py b/dataloaders/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65ed55f33aff2684f01cd42208287677d2d0f6a
--- /dev/null
+++ b/dataloaders/common.py
@@ -0,0 +1,263 @@
+"""
+Dataset classes for common uses
+Extended from vanilla PANet code by Wang et al.
+"""
+import random
+import torch
+
+from torch.utils.data import Dataset
+
+class BaseDataset(Dataset):
+ """
+ Base Dataset
+ Args:
+ base_dir:
+ dataset directory
+ """
+ def __init__(self, base_dir):
+ self._base_dir = base_dir
+ self.aux_attrib = {}
+ self.aux_attrib_args = {}
+ self.ids = [] # must be overloaded in subclass
+
+ def add_attrib(self, key, func, func_args):
+ """
+ Add attribute to the data sample dict
+
+ Args:
+ key:
+ key in the data sample dict for the new attribute
+ e.g. sample['click_map'], sample['depth_map']
+ func:
+ function to process a data sample and create an attribute (e.g. user clicks)
+ func_args:
+ extra arguments to pass, expected a dict
+ """
+ if key in self.aux_attrib:
+ raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key))
+ else:
+ self.set_attrib(key, func, func_args)
+
+ def set_attrib(self, key, func, func_args):
+ """
+ Set attribute in the data sample dict
+
+ Args:
+ key:
+ key in the data sample dict for the new attribute
+ e.g. sample['click_map'], sample['depth_map']
+ func:
+ function to process a data sample and create an attribute (e.g. user clicks)
+ func_args:
+ extra arguments to pass, expected a dict
+ """
+ self.aux_attrib[key] = func
+ self.aux_attrib_args[key] = func_args
+
+ def del_attrib(self, key):
+ """
+ Remove attribute in the data sample dict
+
+ Args:
+ key:
+ key in the data sample dict
+ """
+ self.aux_attrib.pop(key)
+ self.aux_attrib_args.pop(key)
+
+ def subsets(self, sub_ids, sub_args_lst=None):
+ """
+ Create subsets by ids
+
+ Args:
+ sub_ids:
+ a sequence of sequences, each sequence contains data ids for one subset
+ sub_args_lst:
+ a list of args for some subset-specific auxiliary attribute function
+ """
+
+ indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids]
+ if sub_args_lst is not None:
+ subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args)
+ for index, args in zip(indices, sub_args_lst)]
+ else:
+ subsets = [Subset(dataset=self, indices=index) for index in indices]
+ return subsets
+
+ def __len__(self):
+ pass
+
+ def __getitem__(self, idx):
+ pass
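+
+# Usage sketch (mirrors dev_customized_med.py; the ids below are illustrative): a concrete subclass
+# fills self.ids and implements __len__/__getitem__; per-sample auxiliary attributes can then be
+# attached and the data split into class-wise subsets, e.g.
+#   dataset.add_attrib('basic', attrib_basic, {})   # attrib_basic(sample, class_id) -> {'class_id': class_id}
+#   liver_subset, spleen_subset = dataset.subsets([['liver_0', 'liver_1'], ['spleen_0']])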
+
+
+class ReloadPairedDataset(Dataset):
+ """
+ Make pairs of data from dataset
+ Enables loading only part of the entire dataset in each epoch, then reloading the next part
+ Args:
+ datasets:
+ source datasets, expect a list of Dataset.
+ Each dataset indexes a certain class. It contains a list of all z-indices of this class for each scan
+ n_elements:
+ number of elements in a pair
+ curr_max_iters:
+ number of pairs in an epoch
+ pair_based_transforms:
+ some transformation performed on a pair basis, expect a list of functions,
+ each function takes a pair sample and return a transformed one.
+ """
+ def __init__(self, datasets, n_elements, curr_max_iters,
+ pair_based_transforms=None):
+ super().__init__()
+ self.datasets = datasets
+ self.n_datasets = len(self.datasets)
+ self.n_data = [len(dataset) for dataset in self.datasets]
+ self.n_elements = n_elements
+ self.curr_max_iters = curr_max_iters
+ self.pair_based_transforms = pair_based_transforms
+ self.update_index()
+
+ def update_index(self):
+ """
+ update the order of batches for the next episode
+ """
+
+ # update number of elements for each subset
+ if hasattr(self, 'indices'):
+ n_data_old = self.n_data # DEBUG
+ self.n_data = [len(dataset) for dataset in self.datasets]
+
+ if isinstance(self.n_elements, list):
+ self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use
+ for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use
+ for i_iter in range(self.curr_max_iters)] # sample iterations
+
+ elif self.n_elements > self.n_datasets:
+ raise ValueError("'n_elements' should be no more than the number of datasets")
+ else:
+ self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx]))
+ for dataset_idx in random.sample(range(self.n_datasets),
+ k=self.n_elements)]
+ for i in range(self.curr_max_iters)]
+
+ def __len__(self):
+ return self.curr_max_iters
+
+ def __getitem__(self, idx):
+ sample = [self.datasets[dataset_idx][data_idx]
+ for dataset_idx, data_idx in self.indices[idx]]
+ if self.pair_based_transforms is not None:
+ for transform, args in self.pair_based_transforms:
+ sample = transform(sample, **args)
+ return sample
+
+class Subset(Dataset):
+ """
+ Subset of a dataset at specified indices. Used for separating a dataset by class in our context
+
+ Args:
+ dataset:
+ The whole Dataset
+ indices:
+ Indices of samples of the current class in the entire dataset
+ sub_attrib_args:
+ Subset-specific arguments for attribute functions, expected a dict
+ """
+ def __init__(self, dataset, indices, sub_attrib_args=None):
+ self.dataset = dataset
+ self.indices = indices
+ self.sub_attrib_args = sub_attrib_args
+
+ def __getitem__(self, idx):
+ if self.sub_attrib_args is not None:
+ for key in self.sub_attrib_args:
+ # Make sure the dataset already has the corresponding attributes
+ # Here we only make the arguments subset dependent
+ # (i.e. pass different arguments for each subset)
+ self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key])
+ return self.dataset[self.indices[idx]]
+
+ def __len__(self):
+ return len(self.indices)
+
+class ValidationDataset(Dataset):
+ """
+ Dataset for validation
+
+ Args:
+ dataset:
+ source dataset with a __getitem__ method
+ test_classes:
+ test classes
+ npart: int. number of parts, used for evaluation when assigning support images
+
+ """
+ def __init__(self, dataset, test_classes: list, npart: int):
+ super().__init__()
+ self.dataset = dataset
+ self.__curr_cls = None
+ self.test_classes = test_classes
+ self.dataset.aux_attrib = None
+ self.npart = npart
+
+ def set_curr_cls(self, curr_cls):
+ assert curr_cls in self.test_classes
+ self.__curr_cls = curr_cls
+
+ def get_curr_cls(self):
+ return self.__curr_cls
+
+ def read_dataset(self):
+ """
+ override original read_dataset to allow reading with z_margin
+ """
+ raise NotImplementedError
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def label_strip(self, label):
+ """
+ mask unrelated labels out
+ """
+ out = torch.where(label == self.__curr_cls,
+ torch.ones_like(label), torch.zeros_like(label))
+ return out
+
+ def __getitem__(self, idx):
+ if self.__curr_cls is None:
+ raise Exception("Please initialize current class first")
+
+ sample = self.dataset[idx]
+ sample["label"] = self.label_strip( sample["label"] )
+ sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()
+
+ labelname = self.dataset.all_label_names[self.__curr_cls]
+ z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
+ z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
+ sample["z_min"], sample["z_max"] = z_min, z_max
+ try:
+ part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
+ except ZeroDivisionError:
+ # z_max == z_min, i.e. the support has only one valid slice
+ part_assign = 0
+ if part_assign < 0:
+ part_assign = 0
+ elif part_assign >= self.npart:
+ part_assign = self.npart - 1
+ sample["part_assign"] = part_assign
+ sample["case"] = sample["scan_id"]
+
+ return sample
+
+ def get_support_set(self, config, n_support=3):
+ support_batched = self.dataset.get_support(curr_class=self.__curr_cls, class_idx= [self.__curr_cls], scan_idx=config["support_idx"], npart=config["task"]["npart"])
+
+ support_images = [img for way in support_batched["support_images"] for img in way]
+ support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"] for fgmask in way]
+ support_scan_id = self.dataset.potential_support_sid
+ return {"support_images": support_images, "support_labels": support_labels, "support_scan_id": support_scan_id}
+
+
+
diff --git a/dataloaders/dataset_utils.py b/dataloaders/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d6f5752f144df91fbed9336c9693b1810e19871
--- /dev/null
+++ b/dataloaders/dataset_utils.py
@@ -0,0 +1,128 @@
+"""
+Utils for datasets
+"""
+import functools
+import os
+import sys
+
+import numpy as np
+import nibabel as nib
+import SimpleITK as sitk
+
+DATASET_INFO = {
+ "CHAOST2": {
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
+ 'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"],
+ '_SEP': [0, 4, 8, 12, 16, 20],
+ 'MODALITY': 'MR',
+ 'LABEL_GROUP': {
+ 'pa_all': set(range(1, 5)),
+ 0: set([1, 4]), # upper_abdomen, leaving kidneys as testing classes
+ 1: set([2, 3]), # lower_abdomen
+ },
+ },
+
+ "SABS": {
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
+
+ 'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\
+ "PS_VEIN", "PANCREAS", "AG_R", "AG_L"],
+ '_SEP': [0, 6, 12, 18, 24, 30],
+ 'MODALITY': 'CT',
+ 'LABEL_GROUP':{
+ 'pa_all': set( [1,2,3,6] ),
+ 0: set([1, 6]), # upper_abdomen: spleen + liver for training, kidneys for testing
+ 1: set( [2,3] ), # lower_abdomen
+ }
+ },
+ "LITS17": {
+ 'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
+
+ 'REAL_LABEL_NAME': ["BGD", "LIVER", "TUMOR"],
+ '_SEP': [0, 26, 52, 78, 104],
+ 'MODALITY': 'CT',
+ 'LABEL_GROUP':{
+ 'pa_all': set( [1 , 2] ),
+ 0: set([1 ] ), # liver
+ 1: set( [ 2] ), # tumor
+ 2: set([1,2]) # liver + tumor
+ }
+
+ }
+
+}
+
+def read_nii_bysitk(input_fid, peel_info = False):
+ """ read nii to numpy through simpleitk
+
+ peel_info: if True, also return direction, origin, spacing and array size
+ """
+ img_obj = sitk.ReadImage(input_fid)
+ img_np = sitk.GetArrayFromImage(img_obj)
+ if peel_info:
+ info_obj = {
+ "spacing": img_obj.GetSpacing(),
+ "origin": img_obj.GetOrigin(),
+ "direction": img_obj.GetDirection(),
+ "array_size": img_np.shape
+ }
+ return img_np, info_obj
+ else:
+ return img_np
+
+
+def get_CT_statistics(scan_fids):
+ """
+ As CT intensities are quantitative, compute the mean and std of the CT images for normalization.
+ Since we may not be able to load all images at once, it is better to decouple the statistics calculation from the actual data loading.
+ """
+ total_val = 0
+ n_pix = 0
+ for fid in scan_fids:
+ in_img = read_nii_bysitk(fid)
+ total_val += in_img.sum()
+ n_pix += np.prod(in_img.shape)
+ del in_img
+ meanval = total_val / n_pix
+
+ total_var = 0
+ for fid in scan_fids:
+ in_img = read_nii_bysitk(fid)
+ total_var += np.sum((in_img - meanval) ** 2 )
+ del in_img
+ var_all = total_var / n_pix
+
+ global_std = var_all ** 0.5
+
+ return meanval, global_std
+
+def MR_normalize(x_in):
+ return (x_in - x_in.mean()) / x_in.std()
+
+def CT_normalize(x_in, ct_mean, ct_std):
+ """
+ Normalizing CT images, based on global statistics
+ """
+ return (x_in - ct_mean) / ct_std
+
+def get_normalize_op(modality, fids, ct_mean=None, ct_std=None):
+ """
+ Return a normalization function for the given modality
+ Args:
+ modality: CT or MR
+ fids: file ids for the fold (used to compute CT statistics when needed)
+ """
+ if modality == 'MR':
+ return MR_normalize
+
+ elif modality == 'CT':
+ if ct_mean is None or ct_std is None:
+ ct_mean, ct_std = get_CT_statistics(fids)
+ # debug
+ print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean} STD {ct_std} ######')
+
+ return functools.partial(CT_normalize, ct_mean=ct_mean, ct_std=ct_std)
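+
+# Usage sketch (the file ids below are hypothetical): build the normalization op once per fold,
+# then apply it to each loaded slice or volume.
+#   norm_mr = get_normalize_op('MR', fids=None)                  # per-image z-score
+#   norm_ct = get_normalize_op('CT', fids=['scan_0.nii.gz'])     # global CT mean/std over the fold
+#   img_norm = norm_mr(img)                                      # img: numpy array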
+
+
diff --git a/dataloaders/dev_customized_med.py b/dataloaders/dev_customized_med.py
new file mode 100644
index 0000000000000000000000000000000000000000..406cc63e017df7d00a93fccaf54293ada1374d91
--- /dev/null
+++ b/dataloaders/dev_customized_med.py
@@ -0,0 +1,250 @@
+"""
+Customized dataset. Extended from vanilla PANet script by Wang et al.
+"""
+
+import os
+import random
+import torch
+import numpy as np
+
+from dataloaders.common import ReloadPairedDataset, ValidationDataset
+from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
+
+def attrib_basic(_sample, class_id):
+ """
+ Add basic attribute
+ Args:
+ _sample: data sample
+ class_id: class label associated with the data
+ (sometimes indicating from which subset the data are drawn)
+ """
+ return {'class_id': class_id}
+
+def getMaskOnly(label, class_id, class_ids):
+ """
+ Generate FG/BG mask from the segmentation mask
+
+ Args:
+ label:
+ semantic mask
+ scribble:
+ scribble mask
+ class_id:
+ semantic class of interest
+ class_ids:
+ all class id in this episode
+ """
+ # Dense Mask
+ fg_mask = torch.where(label == class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+ bg_mask = torch.where(label != class_id,
+ torch.ones_like(label), torch.zeros_like(label))
+ for cid in class_ids:
+ bg_mask[label == cid] = 0
+
+ return {'fg_mask': fg_mask,
+ 'bg_mask': bg_mask}
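+
+# Example (toy tensor): dense FG/BG masks for the class of interest, with the other episode
+# classes removed from the background.
+#   label = torch.tensor([[0, 1], [2, 1]])
+#   masks = getMaskOnly(label, class_id=1, class_ids=[1, 2])
+#   # masks['fg_mask'] -> [[0, 1], [0, 1]]; masks['bg_mask'] -> [[1, 0], [0, 0]]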
+
+def getMasks(*args, **kwargs):
+ raise NotImplementedError
+
+def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
+ """
+ Postprocess paired sample for fewshot settings
+ For now only 1-way is tested but we leave multi-way possible (inherited from original PANet)
+
+ Args:
+ paired_sample:
+ data sample from a PairedDataset
+ n_ways:
+ n-way few-shot learning
+ n_shots:
+ n-shot few-shot learning
+ cnt_query:
+ number of query images for each class in the support set
+ coco:
+ MS COCO dataset. This is inherited from the original PANet code, but let's keep it for further extension
+ mask_only:
+ only provide masks and no scribbles/instances. Suitable for medical images (for now)
+ """
+ if not mask_only:
+ raise NotImplementedError
+ ###### Compose the support and query image list ######
+ cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # separation between supports and queries
+
+ # support class ids
+ class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query)
+
+ # support images
+ support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
+ for i in range(n_ways)] # fetch support images for each class
+
+ # support image labels
+ if coco:
+ support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
+ for j in range(n_shots)] for i in range(n_ways)]
+ else:
+ support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
+ for i in range(n_ways)]
+
+ if not mask_only:
+ support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
+ for i in range(n_ways)]
+ support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
+ for i in range(n_ways)]
+ else:
+ support_scribbles = [] # no scribbles in mask-only mode; keep the returned interface consistent
+ support_insts = []
+
+ # query images, masks and class indices
+ query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
+ for j in range(cnt_query[i])]
+ if coco:
+ query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
+ for i in range(n_ways) for j in range(cnt_query[i])]
+ else:
+ query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
+ for j in range(cnt_query[i])]
+ query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
+ for x in set(np.unique(query_label)) & set(class_ids)])
+ for query_label in query_labels]
+
+ ###### Generate support image masks ######
+ if not mask_only:
+ support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
+ class_ids[way], class_ids)
+ for shot in range(n_shots)] for way in range(n_ways)]
+ else:
+ support_mask = [[getMaskOnly(support_labels[way][shot],
+ class_ids[way], class_ids)
+ for shot in range(n_shots)] for way in range(n_ways)]
+
+ ###### Generate query label (class indices in one episode, i.e. the ground truth)######
+ query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
+ for i, query_label_tmp in enumerate(query_labels_tmp):
+ query_label_tmp[query_labels[i] == 255] = 255
+ for j in range(n_ways):
+ query_label_tmp[query_labels[i] == class_ids[j]] = j + 1
+
+ ###### Generate query mask for each semantic class (including BG) ######
+ # BG class
+ query_masks = [[torch.where(query_label == 0,
+ torch.ones_like(query_label),
+ torch.zeros_like(query_label))[None, ...],]
+ for query_label in query_labels]
+ # Other classes in query image
+ for i, query_label in enumerate(query_labels):
+ for idx in query_cls_idx[i][1:]:
+ mask = torch.where(query_label == class_ids[idx - 1],
+ torch.ones_like(query_label),
+ torch.zeros_like(query_label))[None, ...]
+ query_masks[i].append(mask)
+
+
+ return {'class_ids': class_ids,
+ 'support_images': support_images,
+ 'support_mask': support_mask,
+ 'support_inst': support_insts, # leave these interfaces
+ 'support_scribbles': support_scribbles,
+
+ 'query_images': query_images,
+ 'query_labels': query_labels_tmp,
+ 'query_masks': query_masks,
+ 'query_cls_idx': query_cls_idx,
+ }
+
+
+def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
+ transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs):
+ """
+ Dataset wrapper
+ Args:
+ dataset_name:
+ indicates what dataset to use
+ base_dir:
+ dataset directory
+ mode:
+ which mode to use
+ choose from ('train', 'val', 'trainval', 'trainaug')
+ idx_split:
+ index of split
+ scan_per_load:
+ number of scans to load into memory as the dataset is large
+ use that together with reload_buffer
+ transforms:
+ transformations to be performed on images/masks
+ act_labels:
+ active labels involved in training process. Should be a subset of all labels
+ n_ways:
+ n-way few-shot learning, should be no more than # of object class labels
+ n_shots:
+ n-shot few-shot learning
+ max_iters_per_load:
+ number of pairs per load (epoch size)
+ n_queries:
+ number of query images
+ fix_parent_len:
+ fixed length of the parent dataset
+ """
+ med_set = ManualAnnoDataset
+
+
+ mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\
+ scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\
+ exclude_list = exclude_list, **kwargs)
+
+ mydataset.add_attrib('basic', attrib_basic, {})
+
+ # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside
+ subsets = mydataset.subsets([{'basic': {'class_id': ii}}
+ for ii, _ in enumerate(mydataset.label_name)])
+
+ # Choose the classes of queries
+ cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)
+ # Number of queries for each way
+ # Set the number of images for each class
+ n_elements = [n_shots + x for x in cnt_query] # supports + [i] queries
+ # Create paired dataset. We do not include background.
+ paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load,
+ pair_based_transforms=[
+ (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
+ 'cnt_query': cnt_query, 'mask_only': True})])
+ return paired_data, mydataset
+
+def update_loader_dset(loader, parent_set):
+ """
+ Update data loader and the parent dataset behind
+ Args:
+ loader: actual dataloader
+ parent_set: parent dataset which actually stores the data
+ """
+ parent_set.reload_buffer()
+ loader.dataset.update_index()
+ print('###### Loader and dataset have been updated ######')
+
+def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs):
+ """
+ validation set for med images
+ Args:
+ dataset_name:
+ indicates what dataset to use
+ base_dir:
+ SABS dataset directory
+ mode: (original split)
+ which split to use
+ choose from ('train', 'val', 'trainval', 'trainaug')
+ idx_split:
+ index of split
+ scan_per_load:
+ number of scans to load into memory as the dataset is large
+ use that together with reload_buffer
+ act_labels:
+ labels to evaluate on. Should be a subset of all labels
+ npart: number of chunks for splitting a 3d volume
+ nsup: number of support scans, equivalent to nshot
+ """
+ mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode, scan_per_load = scan_per_load, transforms=transforms, min_fg = 1, fix_length = fix_length, nsup = nsup, **kwargs)
+ mydataset.add_attrib('basic', attrib_basic, {})
+
+ valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart)
+
+ return valset, mydataset
\ No newline at end of file
diff --git a/dataloaders/image_transforms.py b/dataloaders/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..80a2a417c080953ebb9fd2abb67ef91806a12c9e
--- /dev/null
+++ b/dataloaders/image_transforms.py
@@ -0,0 +1,362 @@
+"""
+Image transforms functions for data augmentation
+Credit to Dr. Jo Schlemper
+"""
+
+from collections.abc import Sequence
+import time
+
+import cv2
+import numpy as np
+import albumentations as A
+from numpy.lib.stride_tricks import as_strided
+from scipy.ndimage import gaussian_filter, map_coordinates
+
+###### UTILITIES ######
+def random_num_generator(config, random_state=np.random):
+ if config[0] == 'uniform':
+ ret = random_state.uniform(config[1], config[2], 1)[0]
+ elif config[0] == 'lognormal':
+ ret = random_state.lognormal(config[1], config[2], 1)[0]
+ else:
+ #print(config)
+ raise Exception('unsupported format')
+ return ret
+
+def get_translation_matrix(translation):
+ """ translation: [tx, ty] """
+ tx, ty = translation
+ translation_matrix = np.array([[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]])
+ return translation_matrix
+
+
+
+def get_rotation_matrix(rotation, input_shape, centred=True):
+ theta = np.pi / 180 * np.array(rotation)
+ if centred:
+ rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1)
+ rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]])
+ else:
+ rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
+ [np.sin(theta), np.cos(theta), 0],
+ [0, 0, 1]])
+ return rotation_matrix
+
+def get_zoom_matrix(zoom, input_shape, centred=True):
+ zx, zy = zoom
+ if centred:
+ zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0])
+ zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]])
+ else:
+ zoom_matrix = np.array([[zx, 0, 0],
+ [0, zy, 0],
+ [0, 0, 1]])
+ return zoom_matrix
+
+def get_shear_matrix(shear_angle):
+ theta = (np.pi * shear_angle) / 180
+ shear_matrix = np.array([[1, -np.sin(theta), 0],
+ [0, np.cos(theta), 0],
+ [0, 0, 1]])
+ return shear_matrix
+
+###### AFFINE TRANSFORM ######
+class RandomAffine(object):
+ """Apply random affine transformation on a numpy.ndarray (H x W x C)
+ Comment by co1818: this is still doing affine on 2d (H x W plane).
+ The same transform is applied to all C channels.
+
+ Parameters
+ ----------
+ order: interpolation method (c.f. opencv); see __init__ for the full argument list
+ """
+
+ def __init__(self,
+ rotation_range=None,
+ translation_range=None,
+ shear_range=None,
+ zoom_range=None,
+ zoom_keep_aspect=False,
+ interp='bilinear',
+ use_3d=False,
+ order=3):
+ """
+ Perform an affine transforms.
+
+ Arguments
+ ---------
+ rotation_range : one integer or float
+ image will be rotated randomly between (-degrees, degrees)
+
+ translation_range : (x_shift, y_shift)
+ shifts in pixels
+
+ *NOT TESTED* shear_range : float
+ image will be sheared randomly between (-degrees, degrees)
+
+ zoom_range : (zoom_min, zoom_max)
+ list/tuple with two floats between [0, infinity).
+ first float should be less than the second
+ lower and upper bounds on percent zoom.
+ Anything less than 1.0 will zoom in on the image,
+ anything greater than 1.0 will zoom out on the image.
+ e.g. (0.7, 1.0) will only zoom in,
+ (1.0, 1.4) will only zoom out,
+ (0.7, 1.4) will randomly zoom in or out
+ """
+
+ self.rotation_range = rotation_range
+ self.translation_range = translation_range
+ self.shear_range = shear_range
+ self.zoom_range = zoom_range
+ self.zoom_keep_aspect = zoom_keep_aspect
+ self.interp = interp
+ self.order = order
+ self.use_3d = use_3d
+
+ def build_M(self, input_shape):
+ tfx = []
+ final_tfx = np.eye(3)
+ if self.rotation_range:
+ rot = np.random.uniform(-self.rotation_range, self.rotation_range)
+ tfx.append(get_rotation_matrix(rot, input_shape))
+ if self.translation_range:
+ tx = np.random.uniform(-self.translation_range[0], self.translation_range[0])
+ ty = np.random.uniform(-self.translation_range[1], self.translation_range[1])
+ tfx.append(get_translation_matrix((tx,ty)))
+ if self.shear_range:
+ rot = np.random.uniform(-self.shear_range, self.shear_range)
+ tfx.append(get_shear_matrix(rot))
+ if self.zoom_range:
+ sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
+ if self.zoom_keep_aspect:
+ sy = sx
+ else:
+ sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
+
+ tfx.append(get_zoom_matrix((sx, sy), input_shape))
+
+ for tfx_mat in tfx:
+ final_tfx = np.dot(tfx_mat, final_tfx)
+
+ return final_tfx.astype(np.float32)
+
+ def __call__(self, image):
+ # build matrix
+ input_shape = image.shape[:2]
+ M = self.build_M(input_shape)
+
+ res = np.zeros_like(image)
+ #if isinstance(self.interp, Sequence):
+ if type(self.order) is list or type(self.order) is tuple:
+ for i, intp in enumerate(self.order):
+ if self.use_3d:
+ res[..., i] = affine_transform_3d_via_M(image[..., i], M[:2], interp=intp)
+ else:
+ res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp)
+ else:
+ # squeeze if needed
+ orig_shape = image.shape
+ image_s = np.squeeze(image)
+ if self.use_3d:
+ res = affine_transform_3d_via_M(image_s, M[:2], interp=self.order)
+ else:
+ res = affine_transform_via_M(image_s, M[:2], interp=self.order)
+ res = res.reshape(orig_shape)
+
+ #res = affine_transform_via_M(image, M[:2], interp=self.order)
+
+ return res
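+
+# Usage sketch (toy array): one 2D affine is sampled per call and applied to every channel.
+#   tfx = RandomAffine(rotation_range=15, translation_range=(5, 5),
+#                      zoom_range=(0.9, 1.1), zoom_keep_aspect=True, order=1)
+#   out = tfx(np.random.rand(256, 256, 3).astype(np.float32))    # H x W x C in, same shape out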
+
+def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
+ imshape = image.shape
+ shape_size = imshape[:2]
+
+ # Random affine
+ warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1],
+ flags=interp, borderMode=borderMode)
+
+ #print(imshape, warped.shape)
+
+ warped = warped[..., np.newaxis].reshape(imshape)
+
+ return warped
+
+def affine_transform_3d_via_M(vol, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
+ """
+ vol should be of shape (nx, ny, n1, ..., nm)
+ """
+ # go over the volume slice by slice
+ res = np.zeros_like(vol)
+ for i in range(vol.shape[2]):
+ res[:, :, i] = affine_transform_via_M(vol[:,:,i], M, borderMode=borderMode, interp=interp)
+
+ return res
+
+
+###### ELASTIC TRANSFORM ######
+def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
+ """Elastic deformation of image as described in [Simard2003]_.
+ .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
+ Convolutional Neural Networks applied to Visual Document Analysis", in
+ Proc. of the International Conference on Document Analysis and
+ Recognition, 2003.
+ """
+ assert image.ndim == 3
+ shape = image.shape[:2]
+
+ dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
+ sigma, mode="constant", cval=0) * alpha
+ dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
+ sigma, mode="constant", cval=0) * alpha
+
+ x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
+ indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
+ result = np.empty_like(image)
+ for i in range(image.shape[2]):
+ result[:, :, i] = map_coordinates(
+ image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
+ return result
+
+def elastic_transform_nd_3d(image, **kwargs):
+ """
+ image_w_mask should be of shape (nx, ny, nz, 3)
+ """
+ image_w_mask = image
+ start_time = time.time()
+ elastic_tfx = A.ElasticTransform(alpha=10, sigma=20, alpha_affine=15, interpolation=1, border_mode=4, always_apply=True, p=0.5)
+ # print(f"elastic transform initialization took {time.time() - start_time} seconds")
+ img = image_w_mask[..., 0]
+ label = image_w_mask[..., -1]
+ transformed = elastic_tfx(image=img, mask=label)
+ t_img = transformed['image'][..., np.newaxis]
+ t_mask = transformed['mask'][..., np.newaxis]
+ t_mask_bg = 1 - t_mask
+ t_mask = np.concatenate([t_mask_bg, t_mask], axis=-1)
+
+ comp = np.concatenate([t_img, t_mask], axis=-1)
+ return comp
+
+def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False):
+ """Expects data to be (nx, ny, n1 ,..., nm)
+ params:
+ ------
+
+ alpha:
+ the scaling parameter.
+ E.g.: alpha=2 => distorts images up to 2x scaling
+
+ sigma:
+ standard deviation of gaussian filter.
+ E.g.
+ low (sig~=1e-3) => no smoothing, pixelated.
+ high (1/5 * imsize) => smooth, more like affine.
+ very high (1/2*im_size) => translation
+ """
+
+ if random_state is None:
+ random_state = np.random.RandomState(None)
+
+ shape = image.shape
+ imsize = shape[:2]
+ dim = shape[2:]
+
+ # Random affine
+ blur_size = int(4*sigma) | 1
+ dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
+ ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
+ dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
+ ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
+
+ # use as_strided to copy things over across n1...nn channels
+ dx = as_strided(dx.astype(np.float32),
+ strides=(0,) * len(dim) + (4*shape[1], 4),
+ shape=dim+(shape[0], shape[1]))
+ dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim))))
+
+ dy = as_strided(dy.astype(np.float32),
+ strides=(0,) * len(dim) + (4*shape[1], 4),
+ shape=dim+(shape[0], shape[1]))
+ dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim))))
+
+ coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim])
+ indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:],
+ [dy, dx] + [0] * len(dim))]
+
+ if lazy:
+ return indices
+ res = map_coordinates(image, indices, order=order, mode='reflect').reshape(shape)
+ return res
+
+class ElasticTransform(object):
+ """Apply elastic transformation on a numpy.ndarray (H x W x C)
+ """
+
+ def __init__(self, alpha, sigma, order=1):
+ self.alpha = alpha
+ self.sigma = sigma
+ self.order = order
+
+ def __call__(self, image):
+ if isinstance(self.alpha, Sequence):
+ alpha = random_num_generator(self.alpha)
+ else:
+ alpha = self.alpha
+ if isinstance(self.sigma, Sequence):
+ sigma = random_num_generator(self.sigma)
+ else:
+ sigma = self.sigma
+ return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order)
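+
+# Usage sketch (toy array): small sigma gives pixel-level jitter, large sigma approaches a smooth warp.
+#   etfx = ElasticTransform(alpha=10, sigma=5, order=1)
+#   out = etfx(np.random.rand(128, 128, 2).astype(np.float32))   # H x W x C in, same shape out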
+
+class RandomFlip3D(object):
+
+ def __init__(self, h=True, v=True, t=True, p=0.5):
+ """
+ Randomly flip an image horizontally and/or vertically with
+ some probability.
+
+ Arguments
+ ---------
+ h : boolean
+ whether to horizontally flip w/ probability p
+
+ v : boolean
+ whether to vertically flip w/ probability p
+
+ t : boolean
+ whether to flip along the depth (last) axis w/ probability p
+
+ p : float between [0,1]
+ probability with which to apply allowed flipping operations
+ """
+ self.horizontal = h
+ self.vertical = v
+ self.depth = t
+ self.p = p
+
+ def __call__(self, x, y=None):
+ # horizontal flip with p = self.p
+ if self.horizontal:
+ if np.random.random() < self.p:
+ x = x[::-1, ...]
+
+ # vertical flip with p = self.p
+ if self.vertical:
+ if np.random.random() < self.p:
+ x = x[:, ::-1, ...]
+
+ if self.depth:
+ if np.random.random() < self.p:
+ x = x[..., ::-1]
+
+ return x
+
+
diff --git a/dataloaders/niftiio.py b/dataloaders/niftiio.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5764dd3b073d039313553e85fb1d1c86905e913
--- /dev/null
+++ b/dataloaders/niftiio.py
@@ -0,0 +1,48 @@
+"""
+Utils for datasets
+"""
+import numpy as np
+import SimpleITK as sitk
+
+
+def read_nii_bysitk(input_fid, peel_info = False):
+ """ read nii to numpy through simpleitk
+ peel_info: if True, also return direction, origin, spacing and array size
+ """
+ img_obj = sitk.ReadImage(input_fid)
+ img_np = sitk.GetArrayFromImage(img_obj)
+ if peel_info:
+ info_obj = {
+ "spacing": img_obj.GetSpacing(),
+ "origin": img_obj.GetOrigin(),
+ "direction": img_obj.GetDirection(),
+ "array_size": img_np.shape
+ }
+ return img_np, info_obj
+ else:
+ return img_np
+
+def convert_to_sitk(input_mat, peeled_info):
+ """
+ write a numpy array to sitk image object with essential meta-data
+ """
+ nii_obj = sitk.GetImageFromArray(input_mat)
+ if peeled_info:
+ nii_obj.SetSpacing( peeled_info["spacing"] )
+ nii_obj.SetOrigin( peeled_info["origin"] )
+ nii_obj.SetDirection(peeled_info["direction"] )
+ return nii_obj
+
+def np2itk(img, ref_obj):
+ """
+ img: numpy array
+ ref_obj: reference sitk object for copying information from
+ """
+ itk_obj = sitk.GetImageFromArray(img)
+ itk_obj.SetSpacing( ref_obj.GetSpacing() )
+ itk_obj.SetOrigin( ref_obj.GetOrigin() )
+ itk_obj.SetDirection( ref_obj.GetDirection() )
+ return itk_obj
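+
+# Round-trip sketch (paths are hypothetical): read an array plus its geometry, then write a
+# derived array (e.g. a prediction) back with the same spacing/origin/direction.
+#   img_np, info = read_nii_bysitk("scan.nii.gz", peel_info=True)
+#   pred_obj = convert_to_sitk((img_np > 0).astype(np.uint8), info)
+#   sitk.WriteImage(pred_obj, "pred.nii.gz")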
+
diff --git a/models/ProtoMedSAM.py b/models/ProtoMedSAM.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ce75d585aaa2d563324531213380b6cd5687ff
--- /dev/null
+++ b/models/ProtoMedSAM.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import matplotlib.pyplot as plt
+from models.ProtoSAM import ModelWrapper
+from segment_anything import sam_model_registry
+from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components
+
+class ProtoMedSAM(nn.Module):
+ def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/medsam_vit_b.pth", debug=False, use_cca=False, coarse_pred_only=False):
+ super().__init__()
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ self.image_size = image_size
+ self.coarse_segmentation_model = coarse_segmentation_model
+ self.get_sam(sam_pretrained_path)
+ self.coarse_pred_only = coarse_pred_only
+ self.debug = debug
+ self.use_cca = use_cca
+
+
+ def get_sam(self, checkpoint_path):
+ model_type="vit_b" # TODO make generic?
+ if 'vit_h' in checkpoint_path:
+ model_type = "vit_h"
+ self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
+
+
+ @torch.no_grad()
+ def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
+ box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
+ if len(box_torch.shape) == 2:
+ box_torch = box_torch[:, None, :] # (B, 1, 4)
+
+ sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
+ points=None,
+ boxes=box_torch,
+ masks=None,
+ )
+ low_res_logits, conf = self.medsam.mask_decoder(
+ image_embeddings=img_embed, # (B, 256, 64, 64)
+ image_pe=self.medsam.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
+ multimask_output=query_label is not None,
+ )
+
+ low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
+
+ low_res_pred = F.interpolate(
+ low_res_pred,
+ size=(H, W),
+ mode="bilinear",
+ align_corners=False,
+ ) # (1, 1, gt.shape)
+ low_res_pred = low_res_pred.squeeze().cpu() # (256, 256)
+
+ low_res_pred = low_res_pred.numpy()
+ medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
+
+ if query_label is not None:
+ medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]
+
+ return medsam_seg, conf.cpu().detach().numpy()
+
+ def get_iou(self, pred, label):
+ """
+ pred: np array of shape (h, w), dtype uint8
+ label: np array of shape (h, w), dtype uint8
+ """
+ tp = np.logical_and(pred, label).sum()
+ fp = np.logical_and(pred, 1-label).sum()
+ fn = np.logical_and(1-pred, label).sum()
+ iou = tp / (tp + fp + fn)
+ return iou
+
+ def get_best_mask(self, masks, labels):
+ """
+ masks: np array of shape (B, h, w)
+ labels: torch tensor of shape (1, H, W)
+ """
+ np_labels = labels[0].clone().detach().cpu().numpy()
+ best_iou, best_mask = 0, None
+ for mask in masks:
+ iou = self.get_iou(mask, np_labels)
+ if iou > best_iou:
+ best_iou = iou
+ best_mask = mask
+
+ return best_mask
+
+ def get_bbox(self, pred):
+ """
+ pred is tensor of shape (H,W) - 1 is fg, 0 is bg.
+ return the bbox of pred as np.array([xmin, ymin, xmax, ymax])
+ """
+ if isinstance(pred, np.ndarray):
+ pred = torch.from_numpy(pred)
+ if pred.max() == 0:
+ return None
+ indices = torch.nonzero(pred)
+ ymin, xmin = indices.min(dim=0)[0]
+ ymax, xmax = indices.max(dim=0)[0]
+ return np.array([xmin, ymin, xmax, ymax])
+
+
+ def get_bbox_per_cc(self, conn_components):
+ """
+ conn_components: output of cca function
+ return list of bboxes per connected component, each bbox is a list of 2d points
+ """
+ bboxes = []
+ for i in range(1, conn_components[0]):
+ # get the indices of the foreground points
+ pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
+ bboxes.append(self.get_bbox(pred))
+
+ bboxes = np.array(bboxes)
+ return bboxes
+
+ def forward(self, query_image, coarse_model_input, degrees_rotate=0):
+ """
+ query_image: 3d tensor of shape (1, 3, H, W)
+ images should be normalized with mean and std but not to [0, 1]?
+ """
+ original_size = query_image.shape[-2]
+ # rotate query_image by degrees_rotate
+ rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
+ # print(f"rotating query image took {time.time() - start_time} seconds")
+ coarse_model_input.set_query_images(rotated_img)
+ output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
+ # print(f"ALPNet took {time.time() - start_time} seconds")
+
+ if degrees_rotate != 0:
+ output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
+ # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
+ else:
+ output_logits = output_logits_rot
+
+ # check if softmax is needed
+ # output_p = output_logits.softmax(dim=1)
+ output_p = output_logits
+ pred = output_logits.argmax(dim=1)[0]
+ if self.debug:
+ _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
+ plt.subplot(132)
+ plt.imshow(query_image[0,0].detach().cpu())
+ plt.imshow(_pred, alpha=0.5)
+ plt.subplot(131)
+ # plot heatmap of prob of being fg
+ plt.imshow(output_p[0, 1].detach().cpu())
+ # plot rotated query image and rotated pred
+ output_p_rot = output_logits_rot.softmax(dim=1)
+ _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
+ _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
+ plt.subplot(133)
+ plt.imshow(rotated_img[0, 0].detach().cpu())
+ plt.imshow(_pred_rot, alpha=0.5)
+ plt.savefig('debug/coarse_pred.png')
+ plt.close()
+
+ if self.coarse_pred_only:
+ output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
+ pred = output_logits.argmax(dim=1)[0]
+ conf = get_confidence_from_logits(output_logits)
+ if self.use_cca:
+ _pred = np.array(pred.detach().cpu())
+ _pred, conf = cca(_pred, output_logits, return_conf=True)
+ pred = torch.from_numpy(_pred)
+ if self.training:
+ return output_logits, [conf]
+ return pred, [conf]
+
+ if query_image.shape[-2:] != self.image_size:
+ query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
+ output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
+ if need_softmax(output_logits):
+ output_logits = output_logits.softmax(dim=1)
+
+ output_p = output_logits
+ pred = output_p.argmax(dim=1)[0]
+
+ _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
+ if self.use_cca:
+ conn_components = cca(_pred, output_logits, return_cc=True)
+ conf=None
+ else:
+ conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
+ if self.debug:
+ plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
+ # print(f"connected components took {time.time() - start_time} seconds")
+
+ if _pred.max() == 0:
+ if output_p.shape[-2:] != original_size:
+ output_p = F.interpolate(output_p, size=original_size, mode='bilinear')
+ return output_p.argmax(dim=1)[0], [0]
+
+ H, W = query_image.shape[-2:]
+ # bbox = self.get_bbox(_pred)
+ bbox = self.get_bbox_per_cc(conn_components)
+ bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
+ with torch.no_grad():
+ image_embedding = self.medsam.image_encoder(query_image)
+
+ medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W)
+
+ if self.debug:
+ fig, ax = plt.subplots(1, 2)
+ ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
+ show_mask(medsam_seg, ax[0])
+ ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
+ show_box(bbox[0], ax[1])
+ plt.savefig('debug/medsam_pred.png')
+ plt.close()
+
+ medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
+ if medsam_seg.shape[-2:] != original_size:
+ medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]
+
+ return medsam_seg, [conf]
+
+ def segment_all(self, query_image, query_label):
+ H, W = query_image.shape[-2:]
+ # bbox = self.get_bbox(_pred)
+ # bbox = self.get_bbox_per_cc(conn_components)
+ # bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
+ bbox = np.array([[0, 0, W, H]])
+ query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
+ with torch.no_grad():
+ image_embedding = self.medsam.image_encoder(query_image)
+
+ medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W, query_label)
+
+ if self.debug:
+ fig, ax = plt.subplots(1, 2)
+ ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
+ show_mask(medsam_seg, ax[0])
+ ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
+ show_box(bbox[0], ax[1])
+ plt.savefig('debug/medsam_pred.png')
+ plt.close()
+
+ medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
+ if medsam_seg.shape[-2:] != (H, W):
+ medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')[0][0]
+
+ return medsam_seg.view(H,W), [conf]
+
+
+def show_mask(mask, ax, random_color=False):
+ if random_color:
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+ else:
+ color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
+ h, w = mask.shape[-2:]
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+ ax.imshow(mask_image)
+
+
+def show_box(box, ax):
+ x0, y0 = box[0], box[1]
+ w, h = box[2] - box[0], box[3] - box[1]
+ ax.add_patch(
+ plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
+ )
diff --git a/models/ProtoSAM.py b/models/ProtoSAM.py
new file mode 100644
index 0000000000000000000000000000000000000000..6617ab96774857c486f0ace352eb963931ef78e7
--- /dev/null
+++ b/models/ProtoSAM.py
@@ -0,0 +1,708 @@
+import warnings
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import matplotlib.pyplot as plt
+import numpy as np
+from models.grid_proto_fewshot import FewShotSeg
+from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
+from models.SamWrapper import SamWrapper
+from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits
+from util.lora import inject_trainable_lora
+from models.segment_anything.utils.transforms import ResizeLongestSide
+import cv2
+import time
+from abc import ABC, abstractmethod
+
+CONF_MODE="conf"
+CENTROID_MODE="centroid"
+BOTH_MODE="both"
+POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE)
+
+TYPE_ALPNET="alpnet"
+TYPE_SAM="sam"
+
+def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"):
+ num_labels, labels, stats, centroids = cca_output
+ # Create an output image with random colors for each component
+ output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8)
+ for label in range(1, num_labels): # Start from 1 to skip the background
+ mask = labels == label
+ output_image[mask] = np.random.randint(0, 255, size=3)
+
+ # Plotting the original and the colored components image
+ plt.figure(figsize=(10, 5))
+ plt.subplot(121), plt.imshow(original_image), plt.title('Original Image')
+ plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components')
+ if confidences is not None:
+ # Plot the axes color chart with the confidences, use the same colors as the connected components
+ plt.subplot(122)
+ scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet')
+ plt.colorbar(scatter)
+
+ plt.savefig(title)
+ plt.close()
+
+class SegmentationInput(ABC):
+ @abstractmethod
+ def set_query_images(self, query_images):
+ pass
+
+ def to(self, device):
+ pass
+
+class SegmentationOutput(ABC):
+ @abstractmethod
+ def get_prediction(self):
+ pass
+
+class ALPNetInput(SegmentationInput): # for alpnet
+ def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None):
+ self.supp_imgs = [support_images]
+ self.fore_mask = [support_labels]
+ self.back_mask = [[1 - sup_labels for sup_labels in support_labels]]
+ self.qry_imgs = [query_images]
+ self.isval = isval
+ self.val_wsize = val_wsize
+ self.show_viz = show_viz
+ self.supp_fts = supp_fts
+
+ def set_query_images(self, query_images):
+ self.qry_imgs = [query_images]
+
+ def to(self, device):
+ self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]]
+ self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]]
+ self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]]
+ self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs]
+ if self.supp_fts is not None:
+ self.supp_fts = self.supp_fts.to(device)
+
+class ALPNetOutput(SegmentationOutput):
+ def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts):
+ self.pred = pred
+ self.align_loss = align_loss
+ self.sim_maps = sim_maps
+ self.assign_maps = assign_maps
+ self.proto_grid = proto_grid
+ self.supp_fts = supp_fts
+ self.qry_fts = qry_fts
+
+ def get_prediction(self):
+ return self.pred
+
+class SAMWrapperInput(SegmentationInput):
+ def __init__(self, image, image_labels):
+ self.image = image
+ self.image_labels = image_labels
+
+ def set_query_images(self, query_images):
+ B, C, H, W = query_images.shape
+ if isinstance(query_images, torch.Tensor):
+ query_images = query_images.cpu().detach().numpy()
+ assert B == 1, "batch size must be 1"
+ query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255
+ query_images = query_images.astype(np.uint8)
+ self.image = np.transpose(query_images[0], (1, 2, 0))
+
+ def to(self, device):
+ pass
+
+
+class InputFactory(ABC):
+ @staticmethod
+ def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None):
+
+ if input_type == TYPE_ALPNET:
+ return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts)
+ elif input_type == TYPE_SAM:
+ qimg = np.array(query_image.detach().cpu())
+ B,C,H,W = qimg.shape
+ assert B == 1, "batch size must be 1"
+ gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W)
+ assert np.unique(gts).shape[0] <= 2, "support labels must be binary"
+ gts[gts > 0] = 1
+ qimg = qimg.reshape(H,W,C)
+ qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255
+ qimg = qimg.astype(np.uint8)
+ return SAMWrapperInput(qimg, gts)
+ else:
+            raise ValueError(f"input_type {input_type} is not supported")
+
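+# Illustrative sketch (not executed; tensor names and shapes below are assumptions):
+#
+#   coarse_input = InputFactory.create_input(
+#       input_type=TYPE_ALPNET,
+#       query_image=qry_img,               # (1, 3, H, W) float tensor
+#       support_images=[supp_img],         # list of (1, 3, H, W) float tensors
+#       support_labels=[supp_fg_mask],     # list of (1, H, W) binary tensors
+#       isval=True, val_wsize=2)
+#   coarse_input.to("cuda")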
+
+class ModelWrapper(ABC):
+ def __init__(self, model):
+ self.model = model
+
+ def __call__(self, input_data: SegmentationInput)->SegmentationOutput:
+ pass
+
+ def state_dict(self):
+ return self.model.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self.model.load_state_dict(state_dict)
+
+ def eval(self):
+ self.model.eval()
+
+ def train(self):
+ self.model.train()
+
+ def parameters(self):
+ pass
+
+class ALPNetWrapper(ModelWrapper):
+ def __init__(self, model: FewShotSeg):
+ super().__init__(model)
+
+ def __call__(self, input_data: ALPNetInput):
+ output = self.model(**input_data.__dict__)
+ output = ALPNetOutput(*output)
+ return output.pred
+
+ def parameters(self):
+ return self.model.encoder.parameters()
+
+ def train(self):
+ self.model.encoder.train()
+
+class SamWrapperWrapper(ModelWrapper):
+ def __init__(self, model:SamWrapper):
+ super().__init__(model)
+
+ def __call__(self, input_data: SAMWrapperInput):
+ pred = self.model(**input_data.__dict__)
+ # make pred look like logits
+ pred = torch.tensor(pred).float()[None, None, ...]
+ pred = torch.cat([1-pred, pred], dim=1)
+ return pred
+
+ def to(self, device):
+ self.model.sam.to(device)
+
+class ProtoSAM(nn.Module):
+ def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ):
+ super().__init__()
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ self.image_size = image_size
+ self.coarse_segmentation_model = coarse_segmentation_model
+ self.get_sam(sam_pretrained_path, use_sam_trans)
+ self.num_points_for_sam = num_points_for_sam
+ self.use_points = use_points
+ self.use_bbox = use_bbox # if False then uses points
+ self.use_mask = use_mask
+ self.use_neg_points = use_neg_points
+ assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask"
+ self.use_cca = use_cca
+ self.point_mode = point_mode
+ if self.point_mode not in POINT_MODES:
+ raise ValueError(f"point mode must be one of {POINT_MODES}")
+ self.debug=debug
+ self.coarse_pred_only = coarse_pred_only
+
+ def get_sam(self, checkpoint_path, use_sam_trans):
+ model_type="vit_b" # TODO make generic?
+ if 'vit_h' in checkpoint_path:
+ model_type = "vit_h"
+ self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
+ self.predictor = SamPredictor(self.sam)
+ self.sam.requires_grad_(False)
+ if use_sam_trans:
+ # sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1])
+ sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size)
+ sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1)
+ sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1)
+ else:
+ sam_trans = None
+
+ self.sam_trans = sam_trans
+
+ def get_bbox(self, pred):
+ '''
+ pred tensor of shape (H, W) where 1 represents foreground and 0 represents background
+ returns a list of 2d points representing the bbox
+ '''
+ if isinstance(pred, np.ndarray):
+ pred = torch.tensor(pred)
+ # get the indices of the foreground points
+ indices = torch.nonzero(pred)
+ # get the min and max of the indices
+ min_x = indices[:, 1].min()
+ max_x = indices[:, 1].max()
+ min_y = indices[:, 0].min()
+ max_y = indices[:, 0].max()
+ # get the bbox
+ bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
+
+
+ return bbox
+
+ def get_bbox_per_cc(self, conn_components):
+ """
+ conn_components: output of cca function
+ return list of bboxes per connected component, each bbox is a list of 2d points
+ """
+ bboxes = []
+ for i in range(1, conn_components[0]):
+ # get the indices of the foreground points
+ indices = torch.nonzero(torch.tensor(conn_components[1] == i))
+ # get the min and max of the indices
+ min_x = indices[:, 1].min()
+ max_x = indices[:, 1].max()
+ min_y = indices[:, 0].min()
+ max_y = indices[:, 0].max()
+ # get the bbox
+ # bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
+ # bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]]
+ # bbox should be in a XYXY format
+ bbox = [min_x, min_y, max_x, max_y]
+ bboxes.append(bbox)
+
+ bboxes = np.array(bboxes)
+ return bboxes
+
+ def get_most_conf_points(self, output_p_fg, pred, k):
+ '''
+ get the k most confident points from pred
+        output_p_fg: 2d tensor of shape (H, W) with per-pixel foreground probabilities
+ pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
+ '''
+ # Create a mask where pred is 1
+ mask = pred.bool()
+
+ # Apply the mask to output_p_fg
+ masked_output_p_fg = output_p_fg[mask]
+ if masked_output_p_fg.numel() == 0:
+ return None, None
+ # Get the top k probabilities and their indices
+ confidences, indices = torch.topk(masked_output_p_fg, k)
+
+ # Get the locations of the top k points in xy format
+ locations = torch.nonzero(mask)[indices]
+ # convert locations to xy format
+ locations = locations[:, [1, 0]]
+ # convert locations to list of lists
+ # points = [loc.tolist() for loc in locations]
+
+ return locations.numpy(), [float(conf.item()) for conf in confidences]
+
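+    # Worked example (assumed values, sketch only): with
+    #     output_p_fg = torch.tensor([[0.1, 0.9],
+    #                                 [0.7, 0.2]])
+    #     pred        = torch.tensor([[0, 1],
+    #                                 [1, 0]])
+    # and k=1, only the foreground pixels (0.9 and 0.7) are considered, and the method
+    # returns locations [[1, 0]] (x=1, y=0 in xy order) with confidences [0.9].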
+
+ def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None):
+ '''
+ points: np array of shape (N, 2) where each row is a point in xy format
+ pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
+        image: tensor of shape (H, W) or (3, H, W) representing the image
+        bboxes: list or np array of shape (N, 4) where each row is a bbox in xyxy format
+ '''
+ warnings.filterwarnings('ignore', category=UserWarning)
+ if isinstance(pred, torch.Tensor):
+ pred = pred.cpu().detach().numpy()
+ if len(image.shape) == 3 and image.shape[0] == 3:
+ image = image.permute(1, 2, 0)
+ if title is None:
+ title="debug/most_conf_points.png"
+
+ fig = plt.figure()
+ image = (image - image.min()) / (image.max() - image.min())
+ plt.imshow(image)
+ plt.imshow(pred, alpha=0.5)
+ for i, point in enumerate(points):
+ plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red')
+ if confidences is not None:
+                plt.text(point[0][0], point[0][1], f"{confidences[i]:.3f}", fontsize=12, color='red')
+ # assume points is a list of lists
+ if bboxes is not None:
+ for bbox in bboxes:
+ if bbox is None:
+ continue
+ bbox = np.array(bbox)
+ # plt.scatter(bbox[:, 1], bbox[:, 0], c='red')
+ # plot a line connecting the points
+ box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]])
+ box = np.vstack([box, box[0]])
+ plt.plot(box[:, 0], box[:, 1], c='green')
+ plt.colorbar()
+ fig.savefig(title)
+ plt.close(fig)
+
+ def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None):
+ if len(image.shape) == 3:
+ image = image.permute(1, 2, 0)
+ image = (image - image.min()) / (image.max() - image.min())
+ for i, (mask, score) in enumerate(zip(masks, scores)):
+ plt.figure(figsize=(10,10))
+ plt.imshow(image)
+ show_mask(mask, plt.gca())
+ if input_point is not None:
+ show_points(input_point, input_label, plt.gca())
+ if input_box is not None:
+ show_box(input_box, plt.gca())
+ plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+ # plt.axis('off')
+ plt.savefig(f'debug/sam_mask_{i+1}.png')
+ plt.close()
+ if i > 5:
+ break
+
+ def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1):
+ """
+ args:
+ conn_components: output of cca function
+            output_p: 4d tensor of shape (1, 2, H, W) with class probabilities
+ get_neg_points: bool, if True then return the negative points
+ l: int, number of negative points to get
+ """
+ sam_input_points = []
+ sam_neg_points = []
+ fg_p = output_p[0, 1].detach().cpu()
+
+ if get_neg_points:
+ # get global negative points
+ bg_p = output_p[0, 0].detach().cpu()
+ bg_p[bg_p < 0.95] = 0
+ bg_pred = torch.where(bg_p > 0, 1, 0)
+ glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1)
+ if self.debug:
+ # plot the bg_p as a heatmap
+ plt.figure()
+ plt.imshow(bg_p)
+ plt.colorbar()
+ plt.savefig('debug/bg_p_heatmap.png')
+ plt.close()
+
+ for i, cc_id in enumerate(np.unique(conn_components[1])):
+ # get self.num_points_for_sam most confident points from pred
+ if cc_id == 0:
+ continue # skip background
+ pred = torch.tensor(conn_components[1] == cc_id).float()
+
+ if self.point_mode == CONF_MODE:
+ points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2)
+ elif self.point_mode == CENTROID_MODE:
+ points = conn_components[3][cc_id][None, :] # (1, 2)
+ confidences = [1 for _ in range(len(points))]
+ elif self.point_mode == BOTH_MODE:
+ points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam)
+ point = conn_components[3][cc_id][None, :]
+ points = np.vstack([points, point]) # (N+1, 2)
+ confidences.append(1)
+ else:
+ raise NotImplementedError(f"point mode {self.point_mode} not implemented")
+ sam_input_points.append(np.array(points))
+
+ if get_neg_points:
+ pred_uint8 = (pred.numpy() * 255).astype(np.uint8)
+
+ # Dilate the mask to expand it
+ kernel_size = 3 # Size of the dilation kernel, adjust accordingly
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
+ dilation_iterations = 10 # Number of times dilation is applied, adjust as needed
+ dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations)
+
+ # Subtract the original mask from the dilated mask
+ # This will give a boundary that is only outside the original mask
+ outside_boundary = dilated_mask - pred_uint8
+
+ # Convert back to torch tensor and normalize
+ boundary = torch.tensor(outside_boundary).float() / 255
+ try:
+ bg_p = output_p[0, 0].detach().cpu()
+ neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l)
+ except RuntimeError as e:
+ # make each point (None, None)
+ neg_points = None
+ # append global negative points to the negative points
+ if neg_points is not None and glob_neg_points is not None:
+ neg_points = np.vstack([neg_points, glob_neg_points])
+ else:
+ neg_points = glob_neg_points if neg_points is None else neg_points
+ if self.debug and neg_points is not None:
+ # draw an image with 2 subplots, one is the pred and the other is the boundary
+ plt.figure()
+ plt.subplot(121)
+ plt.imshow(pred)
+ plt.imshow(boundary, alpha=0.5)
+ # plot the neg points
+ plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
+ plt.subplot(122)
+ plt.imshow(pred)
+ plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
+ plt.savefig('debug/pred_and_boundary.png')
+ plt.close()
+ sam_neg_points.append(neg_points)
+ else:
+ # create a list of None same shape as points
+ sam_neg_points = [None for _ in range(len(sam_input_points))]
+
+        sam_input_labels = np.array([cc_idx + 1 for cc_idx, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))])
+ sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2)
+ # if get_neg_points:
+ sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None
+ if sam_neg_input_points is not None:
+ sam_neg_input_points = sam_neg_points
+ sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) )
+ else:
+ sam_neg_input_points = None
+ sam_neg_input_labels = None
+
+ return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels
+
+    def get_sam_input_mask(self, conn_components):
+        sam_input_masks = []
+        sam_input_mask_labels = []
+        for cc_id in np.unique(conn_components[1]):
+            if cc_id == 0:
+                continue  # skip the background component
+            # binary mask of the current connected component
+            pred = torch.tensor(conn_components[1] == cc_id).float()
+            sam_input_masks.append(pred)
+            sam_input_mask_labels.append(cc_id)
+
+        sam_input_masks = np.stack(sam_input_masks)
+        sam_input_mask_labels = np.array(sam_input_mask_labels)
+
+        return sam_input_masks, sam_input_mask_labels
+
+ def predict_w_masks(self, sam_input_masks, qry_img, original_size):
+ masks = []
+ scores = []
+ for in_mask in sam_input_masks:
+ in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST)
+ in_mask[in_mask == 1] = 10
+ in_mask[in_mask == 0] = -8
+ assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
+ self.predictor.set_image(qry_img)
+ mask, score, _ = self.predictor.predict(
+                mask_input=in_mask[None, ...].astype(np.float32),  # SAM expects float mask logits; uint8 would wrap the -8 background value
+ multimask_output=True)
+ # get max index from score
+ if self.debug:
+ # plot each channel of mask
+ fig, ax = plt.subplots(1, 4, figsize=(15, 5))
+ for i in range(mask.shape[0]):
+ ax[i].imshow(qry_img)
+ ax[i].imshow(mask[i], alpha=0.5)
+ ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18)
+ # ax[i].axis('off')
+ ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST))
+ fig.savefig(f'debug/sam_mask_from_mask_prompts.png')
+ plt.close(fig)
+
+
+ max_index = score.argmax()
+ masks.append(mask[max_index])
+ scores.append(score[max_index])
+
+ return masks, scores
+
+ def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False):
+ masks, scores = [], []
+ self.predictor.set_image(qry_img)
+ # if sam_input_points is None:
+ # sam_input_points = [None for _ in range(len(bboxes))]
+ for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points):
+ assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
+ points = point
+ point_labels = np.array([1] * len(point)) if point is not None else None
+ if self.use_neg_points:
+ neg_points = [npoint for npoint in neg_point if None not in npoint]
+ points = np.vstack([point, *neg_points])
+ point_labels = np.array([1] * len(point) + [0] * len(neg_points))
+ if self.debug:
+ self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points
+ mask, score, _ = self.predictor.predict(
+ point_coords=points,
+ point_labels=point_labels,
+ # box=bbox_xyxy[None, :] if bbox_xyxy is not None else None,
+ box = bbox_xyxy if bbox_xyxy is not None else None,
+ # mask_input=sam_mask_input,
+ return_logits=return_logits,
+ multimask_output=False if self.use_cca else True
+ )
+ # best_pred_idx = np.argmax(score)
+ best_pred_idx = 0
+ masks.append(mask[best_pred_idx])
+ scores.append(score[best_pred_idx])
+
+ if self.debug:
+ # pass
+ self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None)
+
+ return masks, scores
+
+
+ def forward(self, query_image, coarse_model_input, degrees_rotate=0):
+ """
+        query_image: 4d tensor of shape (1, 3, H, W)
+        images are expected to be normalized with mean and std, but not min-max rescaled to [0, 1]
+ """
+        original_size = query_image.shape[-2:]
+ # rotate query_image by degrees_rotate
+ start_time = time.time()
+ rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
+ # print(f"rotating query image took {time.time() - start_time} seconds")
+ start_time = time.time()
+ coarse_model_input.set_query_images(rotated_img)
+ output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
+ # print(f"ALPNet took {time.time() - start_time} seconds")
+
+ if degrees_rotate != 0:
+ start_time = time.time()
+ output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
+ # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
+ else:
+ output_logits = output_logits_rot
+
+ # check if softmax is needed
+ output_p = output_logits.softmax(dim=1)
+ # output_p = output_logits
+ pred = output_logits.argmax(dim=1)[0]
+ if self.debug:
+ _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
+ plt.subplot(132)
+ plt.imshow(query_image[0,0].detach().cpu())
+ plt.imshow(_pred, alpha=0.5)
+ plt.subplot(131)
+ # plot heatmap of prob of being fg
+ plt.imshow(output_p[0, 1].detach().cpu())
+ # plot rotated query image and rotated pred
+ output_p_rot = output_logits_rot.softmax(dim=1)
+ _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
+ _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
+ plt.subplot(133)
+ plt.imshow(rotated_img[0, 0].detach().cpu())
+ plt.imshow(_pred_rot, alpha=0.5)
+ plt.savefig('debug/coarse_pred.png')
+ plt.close()
+
+ if self.coarse_pred_only:
+ output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
+ pred = output_logits.argmax(dim=1)[0]
+ conf = get_confidence_from_logits(output_logits)
+ if self.use_cca:
+ _pred = np.array(pred.detach().cpu())
+ _pred, conf = cca(_pred, output_logits, return_conf=True)
+ pred = torch.from_numpy(_pred)
+ if self.training:
+ return output_logits, [conf]
+ # Ensure pred is a float tensor for consistent visualization
+ return pred.float(), [conf]
+
+ if query_image.shape[-2:] != self.image_size:
+ query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
+ output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
+ # if need_softmax(output_logits):
+ # output_logits = output_logits.softmax(dim=1)
+
+ # output_p = output_logits
+ output_p = output_logits.softmax(dim=1)
+ pred = output_p.argmax(dim=1)[0]
+
+ _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
+ start_time = time.time()
+ if self.use_cca:
+ conn_components = cca(_pred, output_logits, return_cc=True)
+ conf=None
+ else:
+ conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
+ if self.debug:
+ plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
+ # print(f"connected components took {time.time() - start_time} seconds")
+ if _pred.max() == 0:
+ return output_p.argmax(dim=1)[0], [0]
+
+ # get bbox from pred
+ if self.use_bbox:
+ start_time = time.time()
+ try:
+ bboxes = self.get_bbox_per_cc(conn_components)
+            except Exception:
+ bboxes = [None] * conn_components[0]
+ else:
+ bboxes = [None] * conn_components[0]
+ # print(f"getting bboxes took {time.time() - start_time} seconds")
+
+
+ start_time = time.time()
+ if self.use_points:
+ sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1)
+ else:
+ sam_input_points = [None] * conn_components[0]
+ sam_input_point_labels = [None] * conn_components[0]
+ sam_neg_input_points = [None] * conn_components[0]
+ sam_neg_input_labels = [None] * conn_components[0]
+ # print(f"getting sam input points took {time.time() - start_time} seconds")
+
+ if self.use_mask:
+ sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components)
+ else:
+ sam_input_masks = None
+ sam_input_mask_labels = None
+
+ if self.debug and sam_input_points is not None:
+ title = f'debug/most_conf_points.png'
+ if self.use_cca:
+ title = f'debug/most_conf_points_cca.png'
+ # convert points to a list where each item is a list of 2 elements in xy format
+ self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points
+
+ # self.sam_trans = None
+ if self.sam_trans is None:
+            query_image = query_image[0].permute(1, 2, 0).detach().cpu().numpy()  # drop the batch dim -> (H, W, 3)
+ else:
+ query_image = self.sam_trans.apply_image_torch(query_image[0])
+ query_image = self.sam_trans.preprocess(query_image)
+ query_image = query_image.permute(1, 2, 0).detach().cpu().numpy()
+ # mask = self.sam_trans.preprocess(mask)
+
+
+ query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8)
+ if self.use_mask:
+ masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size)
+
+ start_time = time.time()
+ if self.use_points or self.use_bbox:
+ masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False)
+ # print(f"predicting w points/bbox took {time.time() - start_time} seconds")
+
+ pred = sum(masks)
+ if not self.training:
+ pred = pred > 0
+ pred = torch.tensor(pred).float().to(output_p.device)
+
+ # pred = torch.tensor(masks[0]).float().cuda()
+ # resize pred to the size of the input
+ pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]
+
+ return pred, scores
+
+
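+# Illustrative end-to-end sketch (paths, shapes and the ALPNet config are assumptions):
+#
+#   alpnet = ALPNetWrapper(FewShotSeg(image_size=256, cfg=alpnet_cfg))
+#   model = ProtoSAM(image_size=(1024, 1024),
+#                    coarse_segmentation_model=alpnet,
+#                    sam_pretrained_path="pretrained_model/sam_default.pth",
+#                    use_points=True, use_cca=True)
+#   # coarse_input is built via InputFactory.create_input (see the sketch after InputFactory above)
+#   pred, scores = model(qry_img, coarse_input)   # pred: (H, W) mask, scores: per-prompt confidences
+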
+def show_mask(mask, ax, random_color=False):
+ if random_color:
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+ else:
+ color = np.array([30/255, 144/255, 255/255, 0.6])
+ h, w = mask.shape[-2:]
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+ ax.imshow(mask_image)
+
+def show_points(coords, labels, ax, marker_size=375):
+ pos_points = coords[labels==1]
+ neg_points = coords[labels==0]
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+def show_box(box, ax):
+ x0, y0 = box[0], box[1]
+ w, h = box[2] - box[0], box[3] - box[1]
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+
+def need_softmax(tensor, dim=1):
+ return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0))
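+
+# Example (sketch): for raw logits such as torch.tensor([[[[2.0]], [[0.5]]]]) the channel
+# values do not sum to 1, so need_softmax returns True; for already-normalized
+# probabilities like torch.tensor([[[[0.3]], [[0.7]]]]) it returns False.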
+
+
+
+
+
\ No newline at end of file
diff --git a/models/SamWrapper.py b/models/SamWrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e200c9d72b3df8b487f7137f3c5ac8f8e2239651
--- /dev/null
+++ b/models/SamWrapper.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+from models.segment_anything.utils.transforms import ResizeLongestSide
+import cv2
+
+def get_iou(mask, label):
+ tp = (mask * label).sum()
+ fp = (mask * (1-label)).sum()
+ fn = ((1-mask) * label).sum()
+ iou = tp / (tp + fp + fn)
+ return iou
+
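+# Worked example (assumed arrays, sketch only): for mask = np.array([[1, 1], [0, 0]]) and
+# label = np.array([[1, 0], [0, 0]]), tp=1, fp=1, fn=0, so get_iou(mask, label) == 0.5.
+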
+class SamWrapper(nn.Module):
+ def __init__(self,sam_args):
+ """
+ sam_args: dict should include the following
+ {
+ "model_type": "vit_h",
+ "sam_checkpoint": "path to checkpoint" pretrained_model/sam_vit_h.pth
+ }
+ """
+ super().__init__()
+ self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
+ self.mask_generator = SamAutomaticMaskGenerator(self.sam)
+ self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)
+
+ def forward(self, image, image_labels):
+ """
+        generate masks for a single image and
+        return the mask that has the largest IoU with the image label
+        Args:
+            image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+            image_labels (np.ndarray): Binary ground-truth labels in HW uint8 format, used to select the best mask.
+ """
+ image = self.transform.apply_image(image)
+ masks = self.mask_generator.generate(image)
+
+ best_index, best_iou = None, 0
+ for i, mask in enumerate(masks):
+ segmentation = mask['segmentation']
+ iou = get_iou(segmentation.astype(np.uint8), image_labels)
+ if best_index is None or iou > best_iou:
+ best_index = i
+ best_iou = iou
+
+ return masks[best_index]['segmentation']
+
+ def to(self, device):
+ self.sam.to(device)
+ self.mask_generator.to(device)
+ self.mask_generator.predictor.to(device)
+
+
+
+if __name__ == "__main__":
+ sam_args = {
+ "model_type": "vit_h",
+ "sam_checkpoint": "pretrained_model/sam_vit_h.pth"
+ }
+ sam_wrapper = SamWrapper(sam_args).cuda()
+ image = cv2.imread("./Kheops-Pyramid.jpg")
+ image = np.array(image).astype('uint8')
+    image_labels = np.random.randint(0, 2, size=image.shape[:2]).astype(np.uint8)  # dummy binary labels at the input resolution
+ sam_wrapper(image, image_labels)
+
+
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/__pycache__/ProtoSAM.cpython-312.pyc b/models/__pycache__/ProtoSAM.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c016076ac26fe6c2629c15cf243768e0808b0a4e
Binary files /dev/null and b/models/__pycache__/ProtoSAM.cpython-312.pyc differ
diff --git a/models/__pycache__/SamWrapper.cpython-312.pyc b/models/__pycache__/SamWrapper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a0f90638a9b23418f2f77b12f9bc4fcc3229f9f
Binary files /dev/null and b/models/__pycache__/SamWrapper.cpython-312.pyc differ
diff --git a/models/__pycache__/__init__.cpython-312.pyc b/models/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..477d2c8b5e4cd27e2f7a6b0f4197712c155126db
Binary files /dev/null and b/models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/models/__pycache__/alpmodule.cpython-312.pyc b/models/__pycache__/alpmodule.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8deaf542b6c4429a6fe9545355868213ed10bbba
Binary files /dev/null and b/models/__pycache__/alpmodule.cpython-312.pyc differ
diff --git a/models/__pycache__/grid_proto_fewshot.cpython-312.pyc b/models/__pycache__/grid_proto_fewshot.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fbd9a441fb3a8c97ade4cfdd32433b0b31d6b64
Binary files /dev/null and b/models/__pycache__/grid_proto_fewshot.cpython-312.pyc differ
diff --git a/models/alpmodule.py b/models/alpmodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..5572bd4eebe40f663e6aee281b2583690cd442d5
--- /dev/null
+++ b/models/alpmodule.py
@@ -0,0 +1,199 @@
+"""
+ALPModule
+"""
+import torch
+import time
+import math
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+from pdb import set_trace
+import matplotlib.pyplot as plt
+# for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm
+
+def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
+ x_norm = torch.norm(x, p = p, dim = dim) # .detach()
+ x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
+ x = x.div(x_norm.unsqueeze(1).expand_as(x))
+ return x
+
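+# Example (sketch): safe_norm rescales each row of an (N, C) prototype matrix to unit L2
+# norm, e.g. a row [3.0, 4.0] becomes [0.6, 0.8]; norms below eps are clamped to eps,
+# which avoids division by zero for all-zero prototypes. Note the eps tensor is created
+# with .cuda(), so this helper assumes CUDA tensors.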
+
+class MultiProtoAsConv(nn.Module):
+ def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode = 'bilinear'):
+ """
+ ALPModule
+ Args:
+ proto_grid: Grid size when doing multi-prototyping. For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2
+ feature_hw: Spatial size of input feature map
+
+ """
+ super(MultiProtoAsConv, self).__init__()
+ self.feature_hw = feature_hw
+ self.proto_grid = proto_grid
+ self.upsample_mode = upsample_mode
+ kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ]
+ self.kernel_size = kernel_size
+ print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
+ self.avg_pool_op = nn.AvgPool2d( kernel_size )
+
+ if use_attention:
+ self.proto_fg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
+ self.proto_bg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
+ self.fg_mask_projection = nn.Sequential(
+ nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
+ )
+ self.bg_mask_projection = nn.Sequential(
+ nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
+ )
+
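+    # Worked example of the pooling-window arithmetic above: with feature_hw=(64, 64) and
+    # proto_grid=(8, 8), kernel_size evaluates to [8, 8], i.e. every 8x8 patch of the
+    # feature map is averaged into one local prototype.
+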
+ def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False ):
+ if mode == 'mask':
+ pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w]
+            # in case more than one prototype lands in the same location, take the max
+ pred_mask = pred_mask.max(dim = 0)[0].unsqueeze(0)
+ vis_dict = {'proto_assign': pred_mask} # things to visualize
+ if vis_sim:
+ vis_dict['raw_local_sims'] = pred_mask
+ return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w]
+
+ elif mode == 'gridconv':
+ dists = F.conv2d(query, prototypes[..., None, None]) * 20
+
+ pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
+ debug_assign = dists.argmax(dim = 1).float().detach()
+
+ vis_dict = {'proto_assign': debug_assign} # things to visualize
+
+ if vis_sim: # return the similarity for visualization
+ vis_dict['raw_local_sims'] = dists.clone().detach()
+ return pred_grid, [debug_assign], vis_dict
+
+ elif mode == 'gridconv+':
+ dists = F.conv2d(query, prototypes[..., None, None]) * 20
+
+ pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
+            # raw_local_sims = dists.detach()
+
+ debug_assign = dists.argmax(dim = 1).float()
+
+ vis_dict = {'proto_assign': debug_assign}
+ if vis_sim:
+ vis_dict['raw_local_sims'] = dists.clone().detach()
+
+ return pred_grid, [debug_assign], vis_dict
+
+ else:
+ raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
+
+
+ def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval = False):
+ if mode == 'mask':
+ proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
+ / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C
+
+ pro_n = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything
+ pro_n = proto
+ proto_grid = sup_y.clone().detach() # a single prototype for the whole image
+ resized_proto_grid = proto_grid
+ non_zero = torch.nonzero(proto_grid)
+
+ elif mode == 'gridconv':
+ nch = sup_x.shape[1]
+
+ sup_nshot = sup_x.shape[0]
+ # if len(sup_x.shape) > 4:
+ # sup_x = sup_x.squeeze()
+ n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
+ n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc
+ n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
+
+ sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
+
+ # get a grid of prototypes
+ proto_grid = sup_y_g.clone().detach()
+ proto_grid[proto_grid < thresh] = 0
+ # interpolate the grid to the original size
+ non_zero = torch.nonzero(proto_grid)
+
+ resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
+ for index in non_zero:
+                resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]  # paint the full val_wsize x val_wsize cell
+
+ sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
+ protos = n_sup_x[sup_y_g > thresh, :] # npro, nc
+ pro_n = safe_norm(protos)
+
+ elif mode == 'gridconv+':
+ nch = sup_x.shape[1]
+ n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
+ sup_nshot = sup_x.shape[0]
+ n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0)
+ n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
+ sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
+
+ # get a grid of prototypes
+ proto_grid = sup_y_g.clone().detach()
+ proto_grid[proto_grid < thresh] = 0
+ non_zero = torch.nonzero(proto_grid)
+ for i, idx in enumerate(non_zero):
+ proto_grid[0, idx[1], idx[2], idx[3]] = i + 1
+ resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
+ for index in non_zero:
+                resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]  # paint the full val_wsize x val_wsize cell
+
+ sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
+ protos = n_sup_x[sup_y_g > thresh, :]
+
+ glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
+ / (sup_y.sum(dim=(-1, -2)) + 1e-5)
+
+ pro_n = safe_norm(torch.cat( [protos, glb_proto], dim = 0 ))
+ return pro_n, resized_proto_grid, non_zero
+
+ def forward(self, qry, sup_x, sup_y, mode, thresh, isval = False, val_wsize = None, vis_sim = False, get_prototypes=False, **kwargs):
+        """
+        Args (current format):
+            mode: 'mask' / 'gridconv' / 'gridconv+'. 'mask' works as the original (global-prototype) PANet-style prototyping
+            qry: [way(1), nb(1), nc, h, w]
+            sup_x: [way(1), shot, nb(1), nc, h, w]
+            sup_y: [way(1), shot, nb(1), h, w]
+            vis_sim: visualize raw similarities or not
+        Legacy format (kept for reference):
+            qry: [way(1), nc, h, w]
+            sup_x: [nb, nc, h, w]
+            sup_y: [nb, 1, h, w]
+        """
+
+ qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w]
+ sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w]
+ sup_y = sup_y.squeeze(0) # [nshot, 1, h, w]
+
+ def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
+ x_norm = torch.norm(x, p = p, dim = dim) # .detach()
+ x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
+ x = x.div(x_norm.unsqueeze(1).expand_as(x))
+ return x
+ if val_wsize is None:
+ val_wsize = self.avg_pool_op.kernel_size
+ if isinstance(val_wsize, (tuple, list)):
+ val_wsize = val_wsize[0]
+ sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])
+ pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
+ if 0 in pro_n.shape:
+ print("failed to find prototypes")
+ qry_n = qry if mode == 'mask' else safe_norm(qry)
+ pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
+
+ return pred_grid, debug_assign, vis_dict, proto_grid
+
diff --git a/models/backbone/__init__.py b/models/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/backbone/__pycache__/__init__.cpython-312.pyc b/models/backbone/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0034eb049aa1786bb4697929075ae384228280bf
Binary files /dev/null and b/models/backbone/__pycache__/__init__.cpython-312.pyc differ
diff --git a/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc b/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62a61bb0fe4b143084f791d6e93d9c1faea3cc57
Binary files /dev/null and b/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc differ
diff --git a/models/backbone/torchvision_backbones.py b/models/backbone/torchvision_backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..deec7f5717816f4aa8f92a4be11fac367c92ad41
--- /dev/null
+++ b/models/backbone/torchvision_backbones.py
@@ -0,0 +1,58 @@
+"""
+Backbones supported by torchvison.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torchvision
+
+class TVDeeplabRes101Encoder(nn.Module):
+ """
+ FCN-Resnet101 backbone from torchvision deeplabv3
+    No ASPP is used, as we found empirically that it hurts performance
+ """
+ def __init__(self, use_coco_init, aux_dim_keep = 64, use_aspp = False):
+ super().__init__()
+ _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None)
+ if use_coco_init:
+ print("###### NETWORK: Using ms-coco initialization ######")
+ else:
+ print("###### NETWORK: Training from scratch ######")
+
+ _model_list = list(_model.children())
+ self.aux_dim_keep = aux_dim_keep
+ self.backbone = _model_list[0]
+ self.localconv = nn.Conv2d(2048, 256,kernel_size = 1, stride = 1, bias = False) # reduce feature map dimension
+ self.asppconv = nn.Conv2d(256, 256,kernel_size = 1, bias = False)
+
+ _aspp = _model_list[1][0]
+ _conv256 = _model_list[1][1]
+ self.aspp_out = nn.Sequential(*[_aspp, _conv256] )
+ self.use_aspp = use_aspp
+
+ def forward(self, x_in, low_level):
+ """
+ Args:
+ low_level: whether returning aggregated low-level features in FCN
+ """
+ fts = self.backbone(x_in)
+ if self.use_aspp:
+ fts256 = self.aspp_out(fts['out'])
+ high_level_fts = fts256
+ else:
+ fts2048 = fts['out']
+ high_level_fts = self.localconv(fts2048)
+
+ if low_level:
+ low_level_fts = fts['aux'][:, : self.aux_dim_keep]
+ return high_level_fts, low_level_fts
+ else:
+ return high_level_fts
+
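+# Illustrative sketch (shapes assumed): with low_level=False, a (B, 3, 256, 256) batch
+# yields a (B, 256, 32, 32) feature map (the stride-8 ResNet-101 backbone followed by the
+# 1x1 localconv); with low_level=True, an additional (B, 64, 32, 32) auxiliary map is
+# returned as well.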
+
+
+
+
diff --git a/models/grid_proto_fewshot.py b/models/grid_proto_fewshot.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e62bde0de706941281ef06d9c2300711462c753
--- /dev/null
+++ b/models/grid_proto_fewshot.py
@@ -0,0 +1,427 @@
+"""
+ALPNet
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .alpmodule import MultiProtoAsConv
+from .backbone.torchvision_backbones import TVDeeplabRes101Encoder
+from util.consts import DEFAULT_FEATURE_SIZE
+from util.lora import inject_trainable_lora
+# from util.utils import load_config_from_url, plot_dinov2_fts
+import math
+
+# Specify a local path to the repository (or use installed package instead)
+FG_PROT_MODE = 'gridconv+' # using both local and global prototype
+# FG_PROT_MODE = 'mask'
+# using local prototype only. Also 'mask' refers to using global prototype only (as done in vanilla PANet)
+BG_PROT_MODE = 'gridconv'
+
+# thresholds for deciding class of prototypes
+FG_THRESH = 0.95
+BG_THRESH = 0.95
+
+
+class FewShotSeg(nn.Module):
+ """
+ ALPNet
+ Args:
+ in_channels: Number of input channels
+ cfg: Model configurations
+ """
+
+ def __init__(self, image_size, pretrained_path=None, cfg=None):
+ super(FewShotSeg, self).__init__()
+ self.image_size = image_size
+ self.pretrained_path = pretrained_path
+ print(f'###### Pre-trained path: {self.pretrained_path} ######')
+ self.config = cfg or {
+ 'align': False, 'debug': False}
+ self.get_encoder()
+ self.get_cls()
+ if self.pretrained_path:
+ self.load_state_dict(torch.load(self.pretrained_path), strict=True)
+ print(
+                f'###### Pre-trained model {self.pretrained_path} has been loaded ######')
+
+ def get_encoder(self):
+ self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE,
+ DEFAULT_FEATURE_SIZE] # default feature map size
+ if self.config['which_model'] == 'dlfcn_res101' or self.config['which_model'] == 'default':
+ use_coco_init = self.config['use_coco_init']
+ self.encoder = TVDeeplabRes101Encoder(use_coco_init)
+ self.config['feature_hw'] = [
+ math.ceil(self.image_size/8), math.ceil(self.image_size/8)]
+ elif self.config['which_model'] == 'dinov2_l14':
+ self.encoder = torch.hub.load(
+ 'facebookresearch/dinov2', 'dinov2_vitl14')
+ self.config['feature_hw'] = [max(
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
+ elif self.config['which_model'] == 'dinov2_l14_reg':
+ try:
+ self.encoder = torch.hub.load(
+ 'facebookresearch/dinov2', 'dinov2_vitl14_reg')
+ except RuntimeError as e:
+ self.encoder = torch.hub.load(
+                    'facebookresearch/dinov2', 'dinov2_vitl14_reg', force_reload=True)
+ self.config['feature_hw'] = [max(
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
+ elif self.config['which_model'] == 'dinov2_b14':
+ self.encoder = torch.hub.load(
+ 'facebookresearch/dinov2', 'dinov2_vitb14')
+ self.config['feature_hw'] = [max(
+ self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
+ else:
+ raise NotImplementedError(
+ f'Backbone network {self.config["which_model"]} not implemented')
+
+ if self.config['lora'] > 0:
+ self.encoder.requires_grad_(False)
+ print(f'Injecting LoRA with rank:{self.config["lora"]}')
+ encoder_lora_params = inject_trainable_lora(
+ self.encoder, r=self.config['lora'])
+
+ def get_features(self, imgs_concat):
+ if self.config['which_model'] == 'dlfcn_res101':
+ img_fts = self.encoder(imgs_concat, low_level=False)
+ elif 'dino' in self.config['which_model']:
+            # resize imgs_concat to the closest size that is divisible by 14
+ imgs_concat = F.interpolate(imgs_concat, size=(
+ self.image_size // 14 * 14, self.image_size // 14 * 14), mode='bilinear')
+ dino_fts = self.encoder.forward_features(imgs_concat)
+ img_fts = dino_fts["x_norm_patchtokens"] # B, HW, C
+ img_fts = img_fts.permute(0, 2, 1) # B, C, HW
+ C, HW = img_fts.shape[-2:]
+ img_fts = img_fts.view(-1, C, int(HW**0.5),
+ int(HW**0.5)) # B, C, H, W
+ if HW < DEFAULT_FEATURE_SIZE ** 2:
+ img_fts = F.interpolate(img_fts, size=(
+ DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE), mode='bilinear') # this is if h,w < (32,32)
+ else:
+ raise NotImplementedError(
+ f'Backbone network {self.config["which_model"]} not implemented')
+
+ return img_fts
+
+ def get_cls(self):
+ """
+ Obtain the similarity-based classifier
+ """
+ proto_hw = self.config["proto_grid_size"]
+
+ if self.config['cls_name'] == 'grid_proto':
+ embed_dim = 256
+ if 'dinov2_b14' in self.config['which_model']:
+ embed_dim = 768
+ elif 'dinov2_l14' in self.config['which_model']:
+ embed_dim = 1024
+ self.cls_unit = MultiProtoAsConv(proto_grid=[proto_hw, proto_hw], feature_hw=self.config["feature_hw"], embed_dim=embed_dim) # when treating it as ordinary prototype
+ print(f"cls unit feature hw: {self.cls_unit.feature_hw}")
+ else:
+ raise NotImplementedError(
+ f'Classifier {self.config["cls_name"]} not implemented')
+
+ def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
+ predictions = []
+ for res in resolutions:
+ supp_imgs_resized = [[F.interpolate(supp_img[0], size=(
+ res, res), mode='bilinear') for supp_img in supp_imgs]] if supp_imgs[0][0].shape[-1] != res else supp_imgs
+ fore_mask_resized = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
+ 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != res else fore_mask
+ back_mask_resized = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
+ 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != res else back_mask
+ qry_imgs_resized = [F.interpolate(qry_img, size=(res, res), mode='bilinear')
+ for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != res else qry_imgs
+
+ pred = self.forward(supp_imgs_resized, fore_mask_resized, back_mask_resized,
+ qry_imgs_resized, isval, val_wsize, show_viz, supp_fts)[0]
+            predictions.append(pred)
+
+        return predictions
+
+ def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs):
+ supp_imgs = [[F.interpolate(supp_img, size=(
+ self.image_size, self.image_size), mode='bilinear') for supp_img in supp_imgs_way] for supp_imgs_way in supp_imgs]
+ fore_mask = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
+ 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != self.image_size else fore_mask
+ back_mask = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
+ 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != self.image_size else back_mask
+ qry_imgs = [F.interpolate(qry_img, size=(self.image_size, self.image_size), mode='bilinear')
+ for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != self.image_size else qry_imgs
+ return supp_imgs, fore_mask, back_mask, qry_imgs
+
+ def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
+ """
+ Args:
+ supp_imgs: support images
+ way x shot x [B x 3 x H x W], list of lists of tensors
+ fore_mask: foreground masks for support images
+ way x shot x [B x H x W], list of lists of tensors
+ back_mask: background masks for support images
+ way x shot x [B x H x W], list of lists of tensors
+ qry_imgs: query images
+ N x [B x 3 x H x W], list of tensors
+ show_viz: return the visualization dictionary
+ """
+ # ('Please go through this piece of code carefully')
+ # supp_imgs, fore_mask, back_mask, qry_imgs = self.resize_inputs_to_image_size(
+ # supp_imgs, fore_mask, back_mask, qry_imgs)
+
+ n_ways = len(supp_imgs)
+ n_shots = len(supp_imgs[0])
+ n_queries = len(qry_imgs)
+
+ # NOTE: actual shot in support goes in batch dimension
+        assert n_ways == 1, "Multi-way segmentation has not been implemented yet"
+ assert n_queries == 1
+
+ sup_bsize = supp_imgs[0][0].shape[0]
+ img_size = supp_imgs[0][0].shape[-2:]
+ if self.config["cls_name"] == 'grid_proto_3d':
+ img_size = supp_imgs[0][0].shape[-3:]
+ qry_bsize = qry_imgs[0].shape[0]
+
+ imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
+ + [torch.cat(qry_imgs, dim=0),], dim=0)
+
+ img_fts = self.get_features(imgs_concat)
+ if len(img_fts.shape) == 5: # for 3D
+ fts_size = img_fts.shape[-3:]
+ else:
+ fts_size = img_fts.shape[-2:]
+ if supp_fts is None:
+ supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view(
+ n_ways, n_shots, sup_bsize, -1, *fts_size) # wa x sh x b x c x h' x w'
+ qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view(
+ n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W'
+ else:
+ # N x B x C x H' x W'
+ qry_fts = img_fts.view(n_queries, qry_bsize, -1, *fts_size)
+
+ fore_mask = torch.stack([torch.stack(way, dim=0)
+ for way in fore_mask], dim=0) # Wa x Sh x B x H' x W'
+ fore_mask = torch.autograd.Variable(fore_mask, requires_grad=True)
+ back_mask = torch.stack([torch.stack(way, dim=0)
+ for way in back_mask], dim=0) # Wa x Sh x B x H' x W'
+
+ ###### Compute loss ######
+ align_loss = 0
+ outputs = []
+ visualizes = [] # the buffer for visualization
+
+ for epi in range(1): # batch dimension, fixed to 1
+ fg_masks = [] # keep the way part
+
+ '''
+ for way in range(n_ways):
+ # note: index of n_ways starts from 0
+ mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this
+ mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear')
+ fg_masks.append( mean_sup_msk )
+
+ mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W]
+ '''
+ # re-interpolate support mask to the same size as support feature
+ if len(fts_size) == 3: # TODO make more generic
+ res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze(
+ 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
+ res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze(
+ 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
+ else:
+ res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size=fts_size, mode='nearest')
+ for fore_mask_w in fore_mask], dim=0) # [nway, ns, nb, nh', nw']
+ res_bg_msk = torch.stack([F.interpolate(back_mask_w, size=fts_size, mode='nearest')
+ for back_mask_w in back_mask], dim=0) # [nway, ns, nb, nh', nw']
+
+ scores = []
+ assign_maps = []
+ bg_sim_maps = []
+ fg_sim_maps = []
+ bg_mode = BG_PROT_MODE
+
+ _raw_score, _, aux_attr, _ = self.cls_unit(
+ qry_fts, supp_fts, res_bg_msk, mode=bg_mode, thresh=BG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
+ scores.append(_raw_score)
+ assign_maps.append(aux_attr['proto_assign'])
+
+ for way, _msks in enumerate(res_fg_msk):
+ raw_scores = []
+ for i, _msk in enumerate(_msks):
+ _msk = _msk.unsqueeze(0)
+ supp_ft = supp_fts[:, i].unsqueeze(0)
+ if self.config["cls_name"] == 'grid_proto_3d': # 3D
+ k_size = self.cls_unit.kernel_size
+ fg_mode = FG_PROT_MODE if F.avg_pool3d(_msk, k_size).max(
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' # TODO figure out kernel size
+ else:
+ k_size = self.cls_unit.kernel_size
+ fg_mode = FG_PROT_MODE if F.avg_pool2d(_msk, k_size).max(
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
+ # TODO figure out kernel size
+ _raw_score, _, aux_attr, proto_grid = self.cls_unit(qry_fts, supp_ft, _msk.unsqueeze(
+ 0), mode=fg_mode, thresh=FG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
+ raw_scores.append(_raw_score)
+
+ # create a score where each feature is the max of the raw_score
+ _raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[
+ 0]
+ scores.append(_raw_score)
+ assign_maps.append(aux_attr['proto_assign'])
+ if show_viz:
+ fg_sim_maps.append(aux_attr['raw_local_sims'])
+ # print(f"Time for fg: {time.time() - start_time}")
+ pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
+ interpolate_mode = 'bilinear'
+ outputs.append(F.interpolate(
+ pred, size=img_size, mode=interpolate_mode))
+
+ ###### Prototype alignment loss ######
+ if self.config['align'] and self.training:
+ align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
+ fore_mask[:, :, epi], back_mask[:, :, epi])
+ align_loss += align_loss_epi
+
+ output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W
+ grid_shape = output.shape[2:]
+ if self.config["cls_name"] == 'grid_proto_3d':
+ grid_shape = output.shape[2:]
+ output = output.view(-1, *grid_shape)
+ assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None
+ bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz else None
+ fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz else None
+
+ return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts
+
+
+ def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
+ """
+ Compute the loss for the prototype alignment branch
+
+ Args:
+ qry_fts: embedding features for query images
+ expect shape: N x C x H' x W'
+ pred: predicted segmentation score
+ expect shape: N x (1 + Wa) x H x W
+            supp_fts: embedding features for support images
+ expect shape: Wa x Sh x C x H' x W'
+ fore_mask: foreground masks for support images
+ expect shape: way x shot x H x W
+ back_mask: background masks for support images
+ expect shape: way x shot x H x W
+ """
+ n_ways, n_shots = len(fore_mask), len(fore_mask[0])
+
+ # Masks for getting query prototype
+ pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W'
+ binary_masks = [pred_mask == i for i in range(1 + n_ways)]
+
+ # skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
+        # FIXME: for now we make the stronger assumption that a positive class is always present, to avoid under-segmentation / laziness
+ skip_ways = []
+
+ # added for matching dimensions to the new data format
+ qry_fts = qry_fts.unsqueeze(0).unsqueeze(
+ 2) # added to nway(1) and nb(1)
+ # end of added part
+
+ loss = []
+ for way in range(n_ways):
+ if way in skip_ways:
+ continue
+ # Get the query prototypes
+ for shot in range(n_shots):
+ # actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
+ img_fts = supp_fts[way: way + 1, shot: shot + 1]
+ size = img_fts.shape[-2:]
+ mode = 'bilinear'
+ if self.config["cls_name"] == 'grid_proto_3d':
+ size = img_fts.shape[-3:]
+ mode = 'trilinear'
+ qry_pred_fg_msk = F.interpolate(
+ binary_masks[way + 1].float(), size=size, mode=mode) # [1 (way), n (shot), h, w]
+
+ # background
+ qry_pred_bg_msk = F.interpolate(
+ binary_masks[0].float(), size=size, mode=mode) # 1, n, h ,w
+ scores = []
+
+ bg_mode = BG_PROT_MODE
+ _raw_score_bg, _, _, _ = self.cls_unit(
+ qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3), mode=bg_mode, thresh=BG_THRESH)
+
+ scores.append(_raw_score_bg)
+ if self.config["cls_name"] == 'grid_proto_3d':
+ fg_mode = FG_PROT_MODE if F.avg_pool3d(qry_pred_fg_msk, 4).max(
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
+ else:
+ fg_mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max(
+ ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
+ _raw_score_fg, _, _, _ = self.cls_unit(
+ qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2), mode=fg_mode, thresh=FG_THRESH)
+ scores.append(_raw_score_fg)
+
+ supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
+ size = fore_mask.shape[-2:]
+ if self.config["cls_name"] == 'grid_proto_3d':
+ size = fore_mask.shape[-3:]
+ supp_pred = F.interpolate(supp_pred, size=size, mode=mode)
+
+ # Construct the support Ground-Truth segmentation
+ supp_label = torch.full_like(fore_mask[way, shot], 255,
+ device=img_fts.device).long()
+ supp_label[fore_mask[way, shot] == 1] = 1
+ supp_label[back_mask[way, shot] == 1] = 0
+ # Compute Loss
+ loss.append(F.cross_entropy(
+ supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
+
+ return torch.sum(torch.stack(loss))
+
+ def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens):
+ cls_loss_weight = 0.1
+ student_temp = 1
+ teacher_cls_tokens = self.sinkhorn_knopp_teacher(teacher_cls_tokens)
+ lsm = F.log_softmax(student_cls_tokens / student_temp, dim=-1)
+ cls_loss = torch.sum(teacher_cls_tokens * lsm, dim=-1)
+
+ return -cls_loss.mean() * cls_loss_weight
+
+ @torch.no_grad()
+ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3):
+ teacher_output = teacher_output.float()
+ # world_size = dist.get_world_size() if dist.is_initialized() else 1
+ # Q is K-by-B for consistency with notations from our paper
+ Q = torch.exp(teacher_output / teacher_temp).t()
+ # B = Q.shape[1] * world_size # number of samples to assign
+ B = Q.shape[1]
+ K = Q.shape[0] # how many prototypes
+
+ # make the matrix sum to 1
+ sum_Q = torch.sum(Q)
+ Q /= sum_Q
+
+ for it in range(n_iterations):
+ # normalize each row: total weight per prototype must be 1/K
+ sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+ Q /= sum_of_rows
+ Q /= K
+
+ # normalize each column: total weight per sample must be 1/B
+ Q /= torch.sum(Q, dim=0, keepdim=True)
+ Q /= B
+
+ Q *= B # the columns must sum to 1 so that Q is an assignment
+ return Q.t()
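+
+ # Illustrative sketch, not part of the original code: the loop above is the
+ # Sinkhorn-Knopp iteration, alternating row and column normalization so the
+ # returned assignment is approximately doubly stochastic. With a dummy batch of
+ # B=8 samples and K=16 prototype logits:
+ #   q = self.sinkhorn_knopp_teacher(torch.randn(8, 16))
+ #   q.sum(dim=-1)  # ~1 for every sample (rows sum to one)
+ #   q.sum(dim=0)   # ~B/K = 0.5 for every prototype (balanced usage)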
+
+ def dino_patch_loss(self, features, masked_features, masks):
+ # compute the patch-wise loss for both support and query features
+ loss = 0.0
+ weight = 0.1
+ B = features.shape[0]
+ for (f, mf, mask) in zip(features, masked_features, masks):
+ # TODO sinkhorn knopp center features
+ f = f[mask]
+ f = self.sinkhorn_knopp_teacher(f)
+ mf = mf[mask]
+ loss += torch.sum(f * F.log_softmax(mf / 1, dim=-1), dim=-1) / mask.sum()  # student temperature = 1
+
+ return -loss.sum() * weight / B
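+
+ # Illustrative usage sketch (an assumption, not part of the original code); the
+ # names below are hypothetical placeholders:
+ #   cls_loss = model.dino_cls_loss(teacher_cls, student_cls)        # both [B, D]
+ #   patch_loss = model.dino_patch_loss(feats, masked_feats, masks)  # [B, N, D], [B, N, D], [B, N] bool
+ #   total_loss = seg_loss + align_loss + cls_loss + patch_loss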
diff --git a/models/segment_anything/__init__.py b/models/segment_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e6d84e0e9041f4e72e2fab4d01a38744fcd571
--- /dev/null
+++ b/models/segment_anything/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+ build_sam,
+ build_sam_vit_h,
+ build_sam_vit_l,
+ build_sam_vit_b,
+ sam_model_registry,
+)
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
diff --git a/models/segment_anything/automatic_mask_generator.py b/models/segment_anything/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e8a52ae0500004786f1731820a9fbaaea56820
--- /dev/null
+++ b/models/segment_anything/automatic_mask_generator.py
@@ -0,0 +1,380 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: Sam,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ custom_points: bool = False,
+ ) -> None:
+ """
+ Using a SAM model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM with a ViT-H backbone.
+
+ Arguments:
+ model (Sam): The SAM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+ calculating the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ custom_points (bool): If True, the first half of each point batch is
+ treated as positive (foreground) prompts and the second half as
+ negative (background) prompts.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = SamPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+ self.custom_points = custom_points
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
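+
+ # Illustrative usage sketch (not part of the upstream docstring); the checkpoint
+ # path is a placeholder:
+ #   from models.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+ #   sam = sam_model_registry["vit_b"](checkpoint="pretrained_model/sam_vit_b.pth")
+ #   generator = SamAutomaticMaskGenerator(sam, points_per_side=32)
+ #   records = generator.generate(image)  # image: HWC uint8 numpy array
+ #   records[0]["segmentation"], records[0]["bbox"], records[0]["predicted_iou"]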
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+ if self.custom_points:
+ # custom prompt layout: first half of the batch are positive points,
+ # second half are negative points
+ in_pos_labels = torch.ones(in_points.shape[0] // 2, dtype=torch.int, device=in_points.device)
+ in_neg_labels = torch.zeros_like(in_pos_labels)
+ in_labels = torch.cat((in_pos_labels, in_neg_labels), dim=0)
+ else:
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
diff --git a/models/segment_anything/build_sam.py b/models/segment_anything/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc353340ca4fd5c29db7dfae3494ffb4c4c9502
--- /dev/null
+++ b/models/segment_anything/build_sam.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SamBatched
+
+
+def build_sam_vit_h(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1280,
+ encoder_depth=32,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[7, 15, 23, 31],
+ checkpoint=checkpoint,
+ )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1024,
+ encoder_depth=24,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[5, 11, 17, 23],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam_vit_b(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=768,
+ encoder_depth=12,
+ encoder_num_heads=12,
+ encoder_global_attn_indexes=[2, 5, 8, 11],
+ checkpoint=checkpoint,
+ )
+
+
+sam_model_registry = {
+ "default": build_sam_vit_h,
+ "vit_h": build_sam_vit_h,
+ "vit_l": build_sam_vit_l,
+ "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+):
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+ sam = SamBatched(
+ image_encoder=ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ ),
+ prompt_encoder=PromptEncoder(
+ embed_dim=prompt_embed_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(image_size, image_size),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ ),
+ pixel_mean=[123.675, 116.28, 103.53],
+ pixel_std=[58.395, 57.12, 57.375],
+ )
+ sam.eval()
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ sam.load_state_dict(state_dict)
+ return sam
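+
+ # Illustrative usage sketch (the checkpoint filename is a placeholder):
+ #   sam = sam_model_registry["vit_h"](checkpoint="pretrained_model/sam_vit_h.pth")
+ #   sam.to("cuda")  # the builder returns the model in eval() mode; move it as needed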
diff --git a/models/segment_anything/modeling/__init__.py b/models/segment_anything/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a25e343d958d3b60feb84b92f6948ffd1f2a21dd
--- /dev/null
+++ b/models/segment_anything/modeling/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam, SamBatched
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
diff --git a/models/segment_anything/modeling/common.py b/models/segment_anything/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12
--- /dev/null
+++ b/models/segment_anything/modeling/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
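+
+ # Illustrative note (not in the upstream file): unlike nn.LayerNorm, this layer
+ # normalizes an NCHW tensor over the channel dimension only, e.g.
+ #   y = LayerNorm2d(256)(torch.randn(2, 256, 64, 64))
+ #   y.mean(dim=1)  # ~0 at every spatial position with the default weight=1, bias=0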
diff --git a/models/segment_anything/modeling/image_encoder.py b/models/segment_anything/modeling/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..224bffa56aad613c52ef8b2b469c1f518c05f7e3
--- /dev/null
+++ b/models/segment_anything/modeling/image_encoder.py
@@ -0,0 +1,406 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ use_grad_checkpointing: bool = False,
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ use_grad_checkpointing (bool): If True, run each block under gradient
+ checkpointing to trade compute for activation memory.
+ """
+ super().__init__()
+ self.img_size = img_size
+ self.use_grad_checkpointing = use_grad_checkpointing
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ if self.use_grad_checkpointing:
+ blk.use_grad_checkpointing = True
+ x = torch.utils.checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+
+ return x
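+
+ # Shape sketch (illustrative, assuming the constructor defaults above): an input
+ # image of (B, 3, 1024, 1024) is patchified to a (B, 64, 64, embed_dim) token grid,
+ # run through the transformer blocks, and projected by the neck to image
+ # embeddings of shape (B, out_chans, 64, 64), e.g.
+ #   ImageEncoderViT()(torch.randn(1, 3, 1024, 1024)).shape  # torch.Size([1, 256, 64, 64])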
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ use_grad_checkpointing: bool = False,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.use_grad_checkpointing = use_grad_checkpointing
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ if self.use_grad_checkpointing:
+ x = torch.utils.checkpoint.checkpoint(self.attn, x)
+ else:
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
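+
+ # Worked example (illustrative): with a 64x64 feature map and window_size=14 the
+ # map is padded to 70x70 and split into 5*5 = 25 windows per image:
+ #   windows, pad_hw = window_partition(torch.randn(1, 64, 64, 768), 14)
+ #   windows.shape  # torch.Size([25, 14, 14, 768]); pad_hw == (70, 70)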
+
+
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and remove padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
diff --git a/models/segment_anything/modeling/mask_decoder.py b/models/segment_anything/modeling/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4fdb868e1b0340d1bb6b1ee84a20eca27be455
--- /dev/null
+++ b/models/segment_anything/modeling/mask_decoder.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+
+ # Prepare output
+ return masks, iou_pred
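+
+ # Illustrative note: with the default num_multimask_outputs=3 and the standard SAM
+ # configuration (256x256 low-resolution masks), multimask_output=True yields three
+ # candidate masks per prompt while multimask_output=False yields one:
+ #   masks.shape  # (num_prompts, 3, 256, 256) or (num_prompts, 1, 256, 256)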
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
diff --git a/models/segment_anything/modeling/prompt_encoder.py b/models/segment_anything/modeling/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475
--- /dev/null
+++ b/models/segment_anything/modeling/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (tuple(int, int)): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
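+
+ # Illustrative usage sketch (names are placeholders; shapes assume embed_dim=256
+ # and a 64x64 image embedding, as configured in build_sam.py):
+ #   coords = torch.tensor([[[512.0, 512.0]]])  # B=1, N=1 point, (x, y)
+ #   labels = torch.tensor([[1]])               # 1 = foreground point
+ #   sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
+ #   sparse.shape  # (1, 2, 256): the point plus one padding "not a point" token
+ #   dense.shape   # (1, 256, 64, 64): the learned no-mask embedding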
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
diff --git a/models/segment_anything/modeling/sam.py b/models/segment_anything/modeling/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac8cce7babf015a7ae3356e092aa081bcfa076f
--- /dev/null
+++ b/models/segment_anything/modeling/sam.py
@@ -0,0 +1,333 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ # @torch.no_grad()
+ def forward(
+ self,
+ batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if it is not present.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+ a dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input prompts,
+ C is determined by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ if "point_coords" in image_record:
+ points = (image_record["point_coords"], image_record["point_labels"])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get("boxes", None),
+ masks=image_record.get("mask_inputs", None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record["image"].shape[-2:],
+ original_size=image_record["original_size"],
+ )
+ masks = masks > self.mask_threshold
+ outputs.append(
+ {
+ "masks": masks,
+ "iou_predictions": iou_predictions,
+ "low_res_logits": low_res_masks,
+ }
+ )
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="nearest"
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(masks, original_size, mode="nearest")
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_encoder.img_size - h
+ padw = self.image_encoder.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
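+
+# Minimal usage sketch for Sam.forward (illustrative values; assumes `sam` is an
+# already constructed Sam instance and that the image and prompts have been
+# transformed to the model's input frame with ResizeLongestSide):
+#
+#   image = torch.randint(0, 256, (3, 1024, 768)).float()
+#   batched_input = [{
+#       "image": image,                                    # 3 x H x W, input frame
+#       "original_size": (2048, 1536),                     # (H, W) before resizing
+#       "point_coords": torch.tensor([[[512.0, 384.0]]]),  # B x N x 2, (x, y)
+#       "point_labels": torch.tensor([[1]]),               # 1 = foreground point
+#   }]
+#   outputs = sam(batched_input, multimask_output=True)
+#   masks = outputs[0]["masks"]                            # B x C x H x W, bool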
+
+
+class SamBatched(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ # @torch.no_grad()
+ def forward(
+ self,
+        batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if it is not present.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+            a dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input prompts,
+ C is determined by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ with torch.no_grad():
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ if "point_coords" in image_record:
+ points = (image_record["point_coords"], image_record["point_labels"])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get("boxes", None),
+ masks=image_record.get("mask_inputs", None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record["image_size"],
+ original_size=image_record["original_size"],
+ )
+ masks = masks > self.mask_threshold
+ outputs.append(
+ {
+ "masks": masks,
+ "iou_predictions": iou_predictions,
+ "low_res_logits": low_res_masks,
+ }
+ )
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="bilinear",
+ align_corners=True,
+ )
+ masks = masks[..., : int(input_size[0]), : int(input_size[1])]
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=True)
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_encoder.img_size - h
+ padw = self.image_encoder.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
\ No newline at end of file
diff --git a/models/segment_anything/modeling/transformer.py b/models/segment_anything/modeling/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d99f8e8265b5780dd3be1d8c6bbd33156ac1d8f4
--- /dev/null
+++ b/models/segment_anything/modeling/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
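+
+# Minimal shape sketch (illustrative; mirrors the configuration used for SAM's
+# mask decoder, where depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048):
+#
+#   transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
+#   image_embedding = torch.randn(1, 256, 64, 64)  # B x C x h x w
+#   image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
+#   point_embedding = torch.randn(1, 5, 256)       # B x N_points x C
+#   queries, keys = transformer(image_embedding, image_pe, point_embedding)
+#   # queries: 1 x 5 x 256 (processed point embeddings), keys: 1 x 4096 x 256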
diff --git a/models/segment_anything/predictor.py b/models/segment_anything/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..dadf27ea8e9962418f7714d08ebfadcf9c4d3182
--- /dev/null
+++ b/models/segment_anything/predictor.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from .modeling import Sam
+
+from typing import Optional, Tuple
+
+from .utils.transforms import ResizeLongestSide
+
+
+class SamPredictor:
+ def __init__(
+ self,
+ sam_model: Sam,
+ ) -> None:
+ """
+ Uses SAM to calculate the image embedding for an image, and then
+        allows repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam): The model to use for mask prediction.
+ """
+ super().__init__()
+ self.model = sam_model
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+ self.reset_image()
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = "RGB",
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+ assert (
+ len(transformed_image.shape) == 4
+ and transformed_image.shape[1] == 3
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features = self.model.image_encoder(input_image)
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+            box (np.ndarray or None): A length 4 array giving a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks[0].detach().cpu().numpy()
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+ return masks_np, iou_predictions_np, low_res_masks_np
+
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+            boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+                model, in XYXY format.
+            mask_input (torch.Tensor): A low resolution mask input to the model, typically
+                coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert self.features is not None, "Features must exist if an image has been set."
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
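+
+# Minimal usage sketch (illustrative; assumes a constructed `sam` model and an
+# H x W x 3 uint8 RGB image as a numpy array):
+#
+#   predictor = SamPredictor(sam)
+#   predictor.set_image(image)
+#   masks, scores, low_res_logits = predictor.predict(
+#       point_coords=np.array([[500, 375]]),  # N x 2, (x, y) in pixels
+#       point_labels=np.array([1]),           # 1 = foreground
+#       multimask_output=True,
+#   )
+#   best_mask = masks[np.argmax(scores)]      # H x W boolean mask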
diff --git a/models/segment_anything/utils/__init__.py b/models/segment_anything/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/models/segment_anything/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/segment_anything/utils/amg.py b/models/segment_anything/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1464be421bbf473e0131e1be1c40e869b060a86d
--- /dev/null
+++ b/models/segment_anything/utils/amg.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
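+
+# Minimal round-trip sketch (illustrative): a mask encoded with mask_to_rle_pytorch
+# is recovered exactly by rle_to_mask, and area_from_rle counts its foreground pixels.
+#
+#   mask = torch.zeros(1, 4, 6, dtype=torch.bool)
+#   mask[0, 1:3, 2:5] = True
+#   rle = mask_to_rle_pytorch(mask)[0]       # {"size": [4, 6], "counts": [...]}
+#   assert (rle_to_mask(rle) == mask[0].numpy()).all()
+#   assert area_from_rle(rle) == int(mask.sum())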
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of if the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/models/segment_anything/utils/onnx.py b/models/segment_anything/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a9d9e2f1c5990f6b279ef7d1bb847063c68e5e
--- /dev/null
+++ b/models/segment_anything/utils/onnx.py
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+ """
+ This model should not be called directly, but is used in ONNX export.
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+ with some functions modified to enable model tracing. Also supports extra
+    options controlling what information is returned. See the ONNX export script for details.
+ """
+
+ def __init__(
+ self,
+ model: Sam,
+ return_single_mask: bool,
+ use_stability_score: bool = False,
+ return_extra_metrics: bool = False,
+ ) -> None:
+ super().__init__()
+ self.mask_decoder = model.mask_decoder
+ self.model = model
+ self.img_size = model.image_encoder.img_size
+ self.return_single_mask = return_single_mask
+ self.use_stability_score = use_stability_score
+ self.stability_score_offset = 1.0
+ self.return_extra_metrics = return_extra_metrics
+
+ @staticmethod
+ def resize_longest_image_size(
+ input_image_size: torch.Tensor, longest_side: int
+ ) -> torch.Tensor:
+ input_image_size = input_image_size.to(torch.float32)
+ scale = longest_side / torch.max(input_image_size)
+ transformed_size = scale * input_image_size
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+ return transformed_size
+
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+ point_coords = point_coords + 0.5
+ point_coords = point_coords / self.img_size
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+ point_embedding = point_embedding * (point_labels != -1)
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+ point_labels == -1
+ )
+
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+ i
+ ].weight * (point_labels == i)
+
+ return point_embedding
+
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+ mask_embedding = mask_embedding + (
+ 1 - has_mask_input
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+ return mask_embedding
+
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+ masks = F.interpolate(
+ masks,
+ size=(self.img_size, self.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
+
+ orig_im_size = orig_im_size.to(torch.int64)
+ h, w = orig_im_size[0], orig_im_size[1]
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+ return masks
+
+ def select_masks(
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Determine if we should return the multiclick mask or not from the number of points.
+ # The reweighting is used to avoid control flow.
+ score_reweight = torch.tensor(
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+ ).to(iou_preds.device)
+ score = iou_preds + (num_points - 2.5) * score_reweight
+ best_idx = torch.argmax(score, dim=1)
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+ return masks, iou_preds
+
+ @torch.no_grad()
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mask_input: torch.Tensor,
+ has_mask_input: torch.Tensor,
+ orig_im_size: torch.Tensor,
+ ):
+ sparse_embedding = self._embed_points(point_coords, point_labels)
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+ masks, scores = self.model.mask_decoder.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embedding,
+ dense_prompt_embeddings=dense_embedding,
+ )
+
+ if self.use_stability_score:
+ scores = calculate_stability_score(
+ masks, self.model.mask_threshold, self.stability_score_offset
+ )
+
+ if self.return_single_mask:
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+ if self.return_extra_metrics:
+ stability_scores = calculate_stability_score(
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+ )
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+ return upscaled_masks, scores, stability_scores, areas, masks
+
+ return upscaled_masks, scores, masks
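+
+# Minimal export sketch (illustrative; assumes a ViT-H `sam` model, whose prompt
+# embedding dim is 256 and image embedding size is 64 x 64; dummy input shapes
+# follow the official export script, and the output path is a placeholder):
+#
+#   onnx_model = SamOnnxModel(sam, return_single_mask=True)
+#   dummy_inputs = (
+#       torch.randn(1, 256, 64, 64),                # image_embeddings
+#       torch.randint(0, 1024, (1, 5, 2)).float(),  # point_coords
+#       torch.randint(0, 4, (1, 5)).float(),        # point_labels
+#       torch.randn(1, 1, 256, 256),                # mask_input
+#       torch.tensor([1.0]),                        # has_mask_input
+#       torch.tensor([1500.0, 2250.0]),             # orig_im_size, (H, W)
+#   )
+#   torch.onnx.export(
+#       onnx_model, dummy_inputs, "sam_prompt_decoder.onnx",
+#       input_names=["image_embeddings", "point_coords", "point_labels",
+#                    "mask_input", "has_mask_input", "orig_im_size"],
+#       output_names=["masks", "iou_predictions", "low_res_masks"],
+#   )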
diff --git a/models/segment_anything/utils/transforms.py b/models/segment_anything/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..165d19be566815164dfb37d10d08e21d369d34dd
--- /dev/null
+++ b/models/segment_anything/utils/transforms.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image # type: ignore
+from typing import List
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+ """
+ Resizes images to the longest side 'target_length', as well as provides
+ methods for resizing coordinates and boxes. Provides methods for
+ transforming both numpy array and batched torch tensors.
+ """
+
+    def __init__(
+        self,
+        target_length: int,
+        pixel_mean: List[float] = [123.675, 116.28, 103.53],
+        pixel_std: List[float] = [58.395, 57.12, 57.375],
+    ) -> None:
+        self.target_length = target_length
+        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
+        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array shape Bx4. Requires the original image size
+ in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Expects batched images with shape BxCxHxW and float format. This
+ transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+ # Expects an image in BCHW format. May not exactly match apply_image.
+ target_size = self.get_preprocess_shape(image.shape[-2], image.shape[-1], self.target_length)
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ image = F.interpolate(
+ image, target_size,
+ mode="bilinear",
+ align_corners=False,
+ antialias=True
+ )
+ return image.squeeze(0)
+ elif len(image.shape) == 2:
+ image = image.unsqueeze(0).unsqueeze(0)
+ image = F.interpolate(
+ image, target_size,
+ mode="bilinear",
+ align_corners=False,
+ antialias=True
+ )
+ return image.squeeze(0).squeeze(0)
+
+ else:
+ return F.interpolate(
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
+ )
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+        # Normalize colors (single-channel 2D inputs are left unnormalized)
+        if len(x.shape) != 2:
+            device = x.device
+            x = (x - self.pixel_mean.to(device)) / self.pixel_std.to(device)
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.target_length - h
+ padw = self.target_length - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+
+ def apply_coords_torch(
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with shape Bx4. Requires the original image
+ size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
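+
+# Minimal sketch of the coordinate transform (illustrative): with target_length=1024
+# and a 600 x 800 (H x W) image, both the image and its coordinates are scaled by
+# 1024 / 800 = 1.28.
+#
+#   t = ResizeLongestSide(1024)
+#   t.get_preprocess_shape(600, 800, 1024)                  # -> (768, 1024)
+#   t.apply_coords(np.array([[400.0, 300.0]]), (600, 800))  # -> [[512., 384.]]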
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..326776d0f0dcb10584469b59f406f077188f9491
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+torch==2.6.0
+torchvision==0.21.0
+numpy==1.26.4
+matplotlib==3.10.3
+pillow==11.2.1
+opencv-python==4.11.0.86
+tqdm==4.67.1
+sacred==0.8.7
+gradio==5.29.0
+safetensors==0.5.3
+segment-anything==1.0
+kneed==0.8.5
+scikit-image==0.24.0
+scikit-learn==1.6.1
diff --git a/run_demo.sh b/run_demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..426aacbabf83f5f4b4f3aa6b1b9c34cd86b27f51
--- /dev/null
+++ b/run_demo.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Download the SAM model if it doesn't exist
+if [ ! -d "pretrained_model" ]; then
+ mkdir -p pretrained_model
+fi
+
+if [ ! -f "pretrained_model/sam_vit_h.pth" ]; then
+ echo "Downloading SAM ViT-H model..."
+ wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+ mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
+fi
+
+# Run the app
+echo "Running ProtoSAM demo..."
+python app.py
\ No newline at end of file
diff --git a/run_protosam.sh b/run_protosam.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ea0d34165f6b35166a417f851554df1e638084cb
--- /dev/null
+++ b/run_protosam.sh
@@ -0,0 +1,124 @@
+#!/bin/bash
+set -e
+GPUID1=0
+export CUDA_VISIBLE_DEVICES=$GPUID1
+
+# Configs
+MODEL_NAME='dinov2_l14' # relevant for ALPNet, available: dinov2_l14, dinov2_l14_reg, dinov2_b14, dinov2_b14_reg, dlfcn_res101 (deeplabv3)
+COARSE_PRED_ONLY="False" # True will output the coarse segmentation result
+PROTOSAM_SAM_VER="sam_h" # available: sam_h, sam_b, medsam
+INPUT_SIZE=256 # resolution
+ORGAN="rk" # relevant for MRI and CT, available: rk, lk, liver, spleen
+
+# get modality as arg
+MODALITY=$1
+
+PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training
+ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4)
+SEED=42
+
+if [ $MODALITY != "ct" ] && [ $MODALITY != "mri" ] && [ $MODALITY != "polyp" ]
+then
+    echo "modality must be either ct, mri, or polyp"
+ exit 1
+fi
+
+if [ $MODALITY == "ct" ]
+then
+ DATASET='SABS_Superpix'
+fi
+if [ $MODALITY == "mri" ]
+then
+ DATASET='CHAOST2_Superpix'
+fi
+if [ $MODALITY == "polyp" ]
+then
+ DATASET='polyps'
+fi
+
+if [ $INPUT_SIZE -gt 256 ]
+then
+ DATASET=${DATASET}'_672'
+fi
+
+NWORKER=4
+LORA=0
+RELOAD_PATH=( "None" )
+SKIP_SLICES="True"
+DO_CCA="True"
+ALL_SCALE=( "MIDDLE") # config of pseudolabels
+
+if [ $MODALITY == "polyp" ]
+then
+ ORGAN="polyps"
+fi
+
+FREE_DESC=""
+CPT="${MODEL_NAME}_${MODALITY}"
+if [ -n "$FREE_DESC" ]
+then
+ CPT="${CPT}_${FREE_DESC}"
+fi
+
+if [ $LORA -ne 0 ]
+then
+ CPT="${CPT}_lora_${LORA}"
+fi
+
+if [ $DO_CCA = "True" ]
+then
+ CPT="${CPT}_cca"
+fi
+
+CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}_${ORGAN}_fold"
+
+SUPP_ID='[6]'
+if [ $MODALITY == "mri" ]
+then
+ SUPP_ID='[4]'
+fi
+
+echo ===================================
+
+for ((i=0; i<${#ALL_EV[@]}; i++))
+do
+ EVAL_FOLD=${ALL_EV[i]}
+ CPT_W_FOLD="${CPT}_${EVAL_FOLD}"
+ echo $CPT_W_FOLD on GPU $GPUID1
+ for SUPERPIX_SCALE in "${ALL_SCALE[@]}"
+ do
+ PREFIX="test_vfold${EVAL_FOLD}"
+ echo $PREFIX
+ LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}"
+
+ if [ ! -d $LOGDIR ]
+ then
+ mkdir -p $LOGDIR
+ fi
+
+ python3 validation_protosam.py with \
+ "modelname=$MODEL_NAME" \
+ "base_model=alpnet" \
+ "coarse_pred_only=$COARSE_PRED_ONLY" \
+ "protosam_sam_ver=$PROTOSAM_SAM_VER" \
+ "curr_cls=$ORGAN" \
+ 'usealign=True' \
+ 'optim_type=sgd' \
+ reload_model_path=${RELOAD_PATH[i]} \
+ num_workers=$NWORKER \
+ scan_per_load=-1 \
+ 'use_wce=True' \
+ exp_prefix=$PREFIX \
+ 'clsname=grid_proto' \
+ eval_fold=$EVAL_FOLD \
+ dataset=$DATASET \
+ proto_grid_size=$PROTO_GRID \
+ min_fg_data=1 seed=$SEED \
+ save_snapshot_every=$SNAPSHOT_INTERVAL \
+ superpix_scale=$SUPERPIX_SCALE \
+ path.log_dir=$LOGDIR \
+ support_idx=$SUPP_ID \
+ lora=$LORA \
+ "input_size=($INPUT_SIZE, $INPUT_SIZE)"
+ done
+done
\ No newline at end of file
diff --git a/training.py b/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..1719c450f9856d02cd70bf8c33cbcb2ddae114ab
--- /dev/null
+++ b/training.py
@@ -0,0 +1,250 @@
+"""
+Training the model
+Extended from original implementation of ALPNet.
+"""
+from scipy.ndimage import distance_transform_edt as eucl_distance
+import os
+import shutil
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import MultiStepLR
+import numpy as np
+from models.grid_proto_fewshot import FewShotSeg
+from torch.utils.tensorboard import SummaryWriter
+from dataloaders.dev_customized_med import med_fewshot
+from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset
+import dataloaders.augutils as myaug
+
+from util.utils import set_seed, t2n, to01, compose_wt_simple
+from util.metric import Metric
+
+from config_ssl_upload import ex
+from tqdm.auto import tqdm
+# import Tensor
+from torch import Tensor
+from typing import List, Tuple, Union, cast, Iterable, Set, Any, Callable, TypeVar
+
+def get_dice_loss(prediction: torch.Tensor, target: torch.Tensor, smooth=1.0):
+ '''
+ prediction: (B, 1, H, W)
+ target: (B, H, W)
+ '''
+ if prediction.shape[1] > 1:
+ # use only the foreground prediction
+ prediction = prediction[:, 1, :, :]
+ prediction = torch.sigmoid(prediction)
+ intersection = (prediction * target).sum(dim=(-2, -1))
+ union = prediction.sum(dim=(-2, -1)) + target.sum(dim=(1, 2)) + smooth
+
+ dice = (2.0 * intersection + smooth) / union
+ dice_loss = 1.0 - dice.mean()
+
+ return dice_loss
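+
+# Minimal sanity-check sketch for get_dice_loss (illustrative): confidently correct
+# logits give a loss near 0, confidently wrong logits a loss near 1.
+#
+#   target = torch.zeros(2, 32, 32)
+#   target[:, 8:24, 8:24] = 1.0
+#   logits = (target * 40.0 - 20.0).unsqueeze(1)  # (B, 1, H, W), +20 on fg, -20 on bg
+#   get_dice_loss(logits, target)                 # ~0.0
+#   get_dice_loss(-logits, target)                # ~1.0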
+
+
+def get_train_transforms(_config):
+ tr_transforms = myaug.transform_with_label(
+ {'aug': myaug.get_aug(_config['which_aug'], _config['input_size'][0])})
+ return tr_transforms
+
+
+def get_dataset_base_name(data_name):
+ if data_name == 'SABS_Superpix':
+ baseset_name = 'SABS'
+ elif data_name == 'C0_Superpix':
+ raise NotImplementedError
+ baseset_name = 'C0'
+ elif data_name == 'CHAOST2_Superpix':
+ baseset_name = 'CHAOST2'
+ elif data_name == 'CHAOST2_Superpix_672':
+ baseset_name = 'CHAOST2'
+ elif data_name == 'SABS_Superpix_448':
+ baseset_name = 'SABS'
+ elif data_name == 'SABS_Superpix_672':
+ baseset_name = 'SABS'
+ elif 'lits' in data_name.lower():
+ baseset_name = 'LITS17'
+ else:
+ raise ValueError(f'Dataset: {data_name} not found')
+
+ return baseset_name
+
+def get_nii_dataset(_config):
+ data_name = _config['dataset']
+ baseset_name = get_dataset_base_name(data_name)
+ tr_transforms = get_train_transforms(_config)
+ tr_parent = SuperpixelDataset( # base dataset
+ which_dataset=baseset_name,
+ base_dir=_config['path'][data_name]['data_dir'],
+ idx_split=_config['eval_fold'],
+ mode='train',
+ # dummy entry for superpixel dataset
+ min_fg=str(_config["min_fg_data"]),
+ image_size=_config["input_size"][0],
+ transforms=tr_transforms,
+ nsup=_config['task']['n_shots'],
+ scan_per_load=_config['scan_per_load'],
+ exclude_list=_config["exclude_cls_list"],
+ superpix_scale=_config["superpix_scale"],
+        fix_length=_config["max_iters_per_load"],
+ use_clahe=_config['use_clahe'],
+ use_3_slices=_config["use_3_slices"],
+ tile_z_dim=3 if not _config["use_3_slices"] else 1,
+ )
+
+ return tr_parent
+
+
+def get_dataset(_config):
+ return get_nii_dataset(_config)
+
+
+@ex.automain
+def main(_run, _config, _log):
+ precision = torch.float32
+ torch.autograd.set_detect_anomaly(True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if _run.observers:
+ os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
+ for source_file, _ in _run.experiment_info['sources']:
+ os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
+ exist_ok=True)
+ _run.observers[0].save_file(source_file, f'source/{source_file}')
+ shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
+
+ set_seed(_config['seed'])
+
+ writer = SummaryWriter(f'{_run.observers[0].dir}/logs')
+ _log.info('###### Create model ######')
+ if _config['reload_model_path'] != '':
+ _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
+ else:
+ _config['reload_model_path'] = None
+ model = FewShotSeg(image_size=_config['input_size'][0], pretrained_path=_config['reload_model_path'], cfg=_config['model'])
+
+ model = model.to(device, precision)
+ model.train()
+
+ _log.info('###### Load data ######')
+ data_name = _config['dataset']
+ tr_parent = get_dataset(_config)
+
+ # dataloaders
+ trainloader = DataLoader(
+ tr_parent,
+ batch_size=_config['batch_size'],
+ shuffle=True,
+ num_workers=_config['num_workers'],
+ pin_memory=True,
+ drop_last=True
+ )
+
+ _log.info('###### Set optimizer ######')
+ if _config['optim_type'] == 'sgd':
+ optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
+ elif _config['optim_type'] == 'adam':
+ optimizer = torch.optim.AdamW(
+ model.parameters(), lr=_config['lr'], eps=1e-5)
+ else:
+ raise NotImplementedError
+
+ scheduler = MultiStepLR(
+ optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma'])
+
+ my_weight = compose_wt_simple(_config["use_wce"], data_name)
+ criterion = nn.CrossEntropyLoss(
+ ignore_index=_config['ignore_label'], weight=my_weight)
+
+    i_iter = 0  # total number of iterations
+    # number of sub-epochs (dataset reloads) to run
+    n_sub_epoches = max(1, _config['n_steps'] // _config['max_iters_per_load'], _config["epochs"])
+    log_loss = {'loss': 0, 'align_loss': 0}
+
+    _log.info('###### Training ######')
+    epoch_losses = []
+    for sub_epoch in range(n_sub_epoches):
+        print(f"Epoch: {sub_epoch}")
+        _log.info(
+            f'###### This is epoch {sub_epoch} of {n_sub_epoches} epochs ######')
+        pbar = tqdm(trainloader)
+        optimizer.zero_grad()
+        losses = []  # per-batch losses for this epoch
+        for idx, sample_batched in enumerate(pbar):
+            i_iter += 1
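+            # episodic batch layout: support tensors/masks are nested lists indexed
+            # [way][shot]; query images come as a flat list of tensors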
+ support_images = [[shot.to(device, precision) for shot in way]
+ for way in sample_batched['support_images']]
+            support_fg_mask = [[shot['fg_mask'].float().to(device, precision) for shot in way]
+                               for way in sample_batched['support_mask']]
+            support_bg_mask = [[shot['bg_mask'].float().to(device, precision) for shot in way]
+                               for way in sample_batched['support_mask']]
+
+ query_images = [query_image.to(device, precision)
+ for query_image in sample_batched['query_images']]
+ query_labels = torch.cat(
+ [query_label.long().to(device) for query_label in sample_batched['query_labels']], dim=0)
+
+ loss = 0.0
+ try:
+ out = model(support_images, support_fg_mask, support_bg_mask,
+ query_images, isval=False, val_wsize=None)
+ query_pred, align_loss, _, _, _, _, _ = out
+ # pred = np.array(query_pred.argmax(dim=1)[0].cpu())
+ except Exception as e:
+ print(f'faulty batch detected, skip: {e}')
+ # offload cuda memory
+ del support_images, support_fg_mask, support_bg_mask, query_images, query_labels
+ continue
+
+ query_loss = criterion(query_pred.float(), query_labels.long())
+ loss += query_loss + align_loss
+ pbar.set_postfix({'loss': loss.item()})
+ loss.backward()
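+            # gradient accumulation: optimizer and LR scheduler only step every
+            # grad_accumulation_steps batches, so the effective batch size is
+            # batch_size * grad_accumulation_steps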
+ if (idx + 1) % _config['grad_accumulation_steps'] == 0:
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+
+ losses.append(loss.item())
+ query_loss = query_loss.detach().data.cpu().numpy()
+ align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
+
+ _run.log_scalar('loss', query_loss)
+ _run.log_scalar('align_loss', align_loss)
+
+ log_loss['loss'] += query_loss
+ log_loss['align_loss'] += align_loss
+
+ # print loss and take snapshots
+ if (i_iter + 1) % _config['print_interval'] == 0:
+ writer.add_scalar('loss', loss, i_iter)
+ writer.add_scalar('query_loss', query_loss, i_iter)
+ writer.add_scalar('align_loss', align_loss, i_iter)
+
+ loss = log_loss['loss'] / _config['print_interval']
+ align_loss = log_loss['align_loss'] / _config['print_interval']
+
+ log_loss['loss'] = 0
+ log_loss['align_loss'] = 0
+
+ print(
+ f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss},')
+
+ if (i_iter + 1) % _config['save_snapshot_every'] == 0:
+ _log.info('###### Taking snapshot ######')
+ torch.save(model.state_dict(),
+ os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))
+
+            if i_iter >= _config['n_steps']:
+                break  # finish up
+        epoch_losses.append(np.mean(losses))
+        print(f"Epoch {sub_epoch} loss: {np.mean(losses)}")
+        if i_iter >= _config['n_steps']:
+            break  # requested number of steps reached
+
+ # Save the final model regardless of iteration count
+ _log.info('###### Saving final model ######')
+    final_save_path = os.path.join(f'{_run.observers[0].dir}/snapshots', 'final_model.pth')
+ torch.save(model.state_dict(), final_save_path)
+ print(f"Final model saved to: {final_save_path}")
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/util/consts.py b/util/consts.py
new file mode 100644
index 0000000000000000000000000000000000000000..d82d6dc0e9d35dd275aa49ec0c7ed4ebfbff3445
--- /dev/null
+++ b/util/consts.py
@@ -0,0 +1,2 @@
+IMG_SIZE=252 # 256 is original
+DEFAULT_FEATURE_SIZE=32
\ No newline at end of file
diff --git a/util/lora.py b/util/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..2533a01fa77df7de5b35e7e63a62bc5f38e5cf61
--- /dev/null
+++ b/util/lora.py
@@ -0,0 +1,1113 @@
+# copied from https://github.com/cloneofsimo/lora
+import json
+import math
+from itertools import groupby
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from safetensors.torch import safe_open
+ from safetensors.torch import save_file as safe_save
+
+ safetensors_available = True
+except ImportError:
+ from .safe_open import safe_open
+
+ def safe_save(
+ tensors: Dict[str, torch.Tensor],
+ filename: str,
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> None:
+ raise EnvironmentError(
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
+ )
+
+ safetensors_available = False
+
+
+class LoraInjectedLinear(nn.Module):
+ def __init__(
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
+ ):
+ super().__init__()
+
+ if r > min(in_features, out_features):
+ raise ValueError(
+                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
+ )
+ self.r = r
+ self.linear = nn.Linear(in_features, out_features, bias)
+ self.lora_down = nn.Linear(in_features, r, bias=False)
+ self.dropout = nn.Dropout(dropout_p)
+ self.lora_up = nn.Linear(r, out_features, bias=False)
+ self.scale = scale
+ self.selector = nn.Identity()
+
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
+ nn.init.zeros_(self.lora_up.weight)
+
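+    # forward = linear(x) + scale * dropout(lora_up(selector(lora_down(x)))):
+    # the base projection plus a rank-r residual that starts at zero because
+    # lora_up is zero-initialised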
+ def forward(self, input):
+ return (
+ self.linear(input)
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
+ * self.scale
+ )
+
+ def realize_as_lora(self):
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
+
+ def set_selector_from_diag(self, diag: torch.Tensor):
+ # diag is a 1D tensor of size (r,)
+ assert diag.shape == (self.r,)
+ self.selector = nn.Linear(self.r, self.r, bias=False)
+ self.selector.weight.data = torch.diag(diag)
+ self.selector.weight.data = self.selector.weight.data.to(
+ self.lora_up.weight.device
+ ).to(self.lora_up.weight.dtype)
+
+
+class LoraInjectedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups: int = 1,
+ bias: bool = True,
+ r: int = 4,
+ dropout_p: float = 0.1,
+ scale: float = 1.0,
+ ):
+ super().__init__()
+ if r > min(in_channels, out_channels):
+ raise ValueError(
+                f"LoRA rank {r} must be less than or equal to {min(in_channels, out_channels)}"
+ )
+ self.r = r
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+
+ self.lora_down = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=r,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=False,
+ )
+ self.dropout = nn.Dropout(dropout_p)
+ self.lora_up = nn.Conv2d(
+ in_channels=r,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.selector = nn.Identity()
+ self.scale = scale
+
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
+ nn.init.zeros_(self.lora_up.weight)
+
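+    # lora_down mirrors the base conv's kernel/stride/padding/groups but outputs r
+    # channels; lora_up is a 1x1 conv back to out_channels, so the low-rank residual
+    # has the same spatial shape as conv(input)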
+ def forward(self, input):
+ return (
+ self.conv(input)
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
+ * self.scale
+ )
+
+ def realize_as_lora(self):
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
+
+ def set_selector_from_diag(self, diag: torch.Tensor):
+ # diag is a 1D tensor of size (r,)
+ assert diag.shape == (self.r,)
+ self.selector = nn.Conv2d(
+ in_channels=self.r,
+ out_channels=self.r,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.selector.weight.data = torch.diag(diag)
+
+ # same device + dtype as lora_up
+ self.selector.weight.data = self.selector.weight.data.to(
+ self.lora_up.weight.device
+ ).to(self.lora_up.weight.dtype)
+
+
+UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
+
+UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
+
+TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
+
+TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
+
+DINO_TARGET_REPLACE = {"NestedTensorBlock", "Mlp", "Attention", "MemEffAttention"}
+
+DEFAULT_TARGET_REPLACE = DINO_TARGET_REPLACE
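+# Each set above lists module class names to match when injecting LoRA; the
+# injection helpers further down in this file replace the Linear/Conv layers
+# inside matching modules with the LoraInjected* versions defined above.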
+
+EMBED_FLAG = "