diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3e88c65dff890b23ddb559b362dc65682456f201 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +exps/ +.vscode/ +debug/ +test_*/ +pretrained_model/* \ No newline at end of file diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3877ae0a7ff6f94ac222fd704e112723db776114 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md index 6e887e497ce53f89ee0d3d506fdd06fae406ead7..53b5017d8040fe9d2361cc9f60f6cf435d94a5dd 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,157 @@ ---- -title: LoGoSAM Demo -emoji: 📈 -colorFrom: purple -colorTo: yellow -sdk: gradio -sdk_version: 5.30.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: LoGoSAM_demo +app_file: app.py +sdk: gradio +sdk_version: 5.29.0 +--- +# ProtoSAM - One shot segmentation with foundational models + +Link to our paper [here](https://arxiv.org/abs/2407.07042). \ +This work is the successor of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning) (Link to [Paper](arxiv.org/abs/2403.03273)). + +## Demo Application + +A Gradio-based demo application is now available for interactive inference with ProtoSAM. You can upload your own images and masks to test the model. See [README_DEMO.md](README_DEMO.md) for instructions on running the demo. + +## Abstract +This work introduces a new framework, ProtoSAM, for one-shot image segmentation. It combines DINOv2, a vision transformer that extracts features from images, with an Adaptive Local Prototype Pooling (ALP) layer, which generates prototypes from a support image and its mask. These prototypes are used to create an initial coarse segmentation mask by comparing the query image's features with the prototypes. +Following the extraction of an initial mask, we use numerical methods to generate prompts, such as points and bounding boxes, which are then input into the Segment Anything Model (SAM), a prompt-based segmentation model trained on natural images. This allows segmenting new classes automatically and effectively without the need for additional training. + +## How To Run +### 1. Data preprocessing +#### 1.1 CT and MRI Dataset +Please see the notebook `data/data_processing.ipynb` for instructions. +For convenience i've compiled the data processing instructions from https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation to a single notebook. \ +The CT dataset is available here: https://www.synapse.org/Synapse:syn3553734 \ +The MRI dataset is availabel here: https://chaos.grand-challenge.org + +run `./data/CHAOST2/dcm_img_to_nii.sh` to convert dicom images to nifti files. + +#### 1.2 Polyp Dataset +Data is available here: https://www.kaggle.com/datasets/hngphmv/polypdataset?select=train.csv + +Put the dataset `data/PolypDataset/` + +### 2. Running +#### 2.1 (Optional) Training and Validation of the coarse segmentation networks +``` +./backbone.sh [MODE] [MODALITY] [LABEL_SET] +``` +MODE - validation or training \ +MODALITY - ct or mri \ +LABEL_SET - 0 (kidneys), 1 (liver spleen) + +for example: +``` +./backbone.sh training mri 1 +``` +Please refer to `backbone.sh` for further configurations. + +#### 2.1 Running ProtoSAM +Put all SAM checkpoint like sam_vit_b.pth, sam_vit_h.pth, medsam_vit_b.pth into the `pretrained_model` directory. \ +Checkpoints are available at [SAM](https://github.com/facebookresearch/segment-anything) and [MedSAM](https://github.com/bowang-lab/MedSAM). + +``` +./run_protosam.sh [MODALITY] [LABEL_SET] +``` +MODALITY - ct, mri or polyp \ +LABEL_SET (only relevant if doing ct or mri) - 0 (kidneys), 1 (liver spleen) +Please refer to the `run_protosam.sh` script for further configurations. + + +## Acknowledgements +This work is largely based on [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [DINOv2](https://github.com/facebookresearch/dinov2), [SAM](https://github.com/facebookresearch/segment-anything) and is a continuation of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning). + +## Cite +If you found this repo useful, please consider giving us a citation and a star! + +```bibtex +@article{ayzenberg2024protosam, + title={ProtoSAM-One Shot Medical Image Segmentation With Foundational Models}, + author={Ayzenberg, Lev and Giryes, Raja and Greenspan, Hayit}, + journal={arXiv preprint arXiv:2407.07042}, + year={2024} +} + +@misc{ayzenberg2024dinov2, + title={DINOv2 based Self Supervised Learning For Few Shot Medical Image Segmentation}, + author={Lev Ayzenberg and Raja Giryes and Hayit Greenspan}, + year={2024}, + eprint={2403.03273}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +``` + +# ProtoSAM Segmentation Demo + +This Streamlit application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction. + +## Requirements + +- Python 3.8 or higher +- CUDA-compatible GPU +- Required Python packages (see `requirements.txt`) + +## Setup Instructions + +1. Clone this repository: +```bash +git clone +cd +``` + +2. Create and activate a virtual environment (optional but recommended): +```bash +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate +``` + +3. Install the required dependencies: +```bash +pip install -r requirements.txt +``` + +4. Download the pretrained models: +```bash +mkdir -p pretrained_model +# Download SAM ViT-H model +wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth +mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth +``` + +5. Update the model path in `app.py`: + - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model. + +## Running the App + +Start the Streamlit app with: +```bash +streamlit run app.py +``` + +This will open a browser window with the interface for the segmentation demo. + +## Usage + +1. Upload a query image (the image you want to segment) +2. Upload a support image (an example image with a similar object) +3. Upload a support mask (the segmentation mask for the support image) +4. Use the sidebar to configure the model parameters if needed +5. Click "Run Inference" to generate the segmentation result + +## Model Configuration + +The app allows you to configure several model parameters via the sidebar: +- Use Bounding Box: Enable/disable bounding box input +- Use Points: Enable/disable point input +- Use Mask: Enable/disable mask input +- Use CCA: Enable/disable Connected Component Analysis +- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement + +## Notes + +- This demo requires a GPU with CUDA support +- Large images may require more GPU memory +- For optimal results, use high-quality support images and masks diff --git a/README_DEMO.md b/README_DEMO.md new file mode 100644 index 0000000000000000000000000000000000000000..d11c2f105a573ad9df1acd14c4fc14d396bb4a95 --- /dev/null +++ b/README_DEMO.md @@ -0,0 +1,76 @@ +# ProtoSAM Segmentation Demo + +This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction. + +## Requirements + +- Python 3.8 or higher +- CUDA-compatible GPU +- Required Python packages (see `requirements.txt`) + +## Setup Instructions + +1. Clone this repository: +```bash +git clone +cd +``` + +2. Create and activate a virtual environment (optional but recommended): +```bash +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate +``` + +3. Install the required dependencies: +```bash +pip install -r requirements.txt +``` + +4. Download the pretrained models: +```bash +mkdir -p pretrained_model +# Download SAM ViT-H model +wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth +mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth +``` + +5. Update the model path in `app.py`: + - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model. + +## Running the App + +Start the app with: +```bash +./run_demo.sh +``` + +Or run it directly with: +```bash +python app.py +``` + +This will start the server and provide a link to access the demo in your browser. + +## Usage + +1. Upload a query image (the image you want to segment) +2. Upload a support image (an example image with a similar object) +3. Upload a support mask (the segmentation mask for the support image) +4. Configure the model parameters using the checkboxes +5. Click "Run Inference" to generate the segmentation result + +## Model Configuration + +The app allows you to configure several model parameters: +- Use Bounding Box: Enable/disable bounding box input +- Use Points: Enable/disable point input +- Use Mask: Enable/disable mask input +- Use CCA: Enable/disable Connected Component Analysis +- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement + +## Notes + +- This demo requires a GPU with CUDA support +- Large images may require more GPU memory +- For optimal results, use high-quality support images and masks \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..dc11ba7160f0fe2dfa97b7c28bbb81683e686e6c --- /dev/null +++ b/app.py @@ -0,0 +1,247 @@ +import os +import torch +import numpy as np +import matplotlib.pyplot as plt +import gradio as gr +from PIL import Image +import torchvision.transforms as transforms +from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET +from models.grid_proto_fewshot import FewShotSeg +from models.segment_anything.utils.transforms import ResizeLongestSide + +# Set environment variables for model caching +os.environ['TORCH_HOME'] = "./pretrained_model" + +# Function to load the model +def load_model(config): + # Initial segmentation model + alpnet = FewShotSeg( + config["input_size"][0], + config["reload_model_path"], + config["model"] + ) + alpnet.cuda() + base_model = ALPNetWrapper(alpnet) + + # ProtoSAM model + sam_checkpoint = "pretrained_model/sam_vit_h.pth" + model = ProtoSAM( + image_size=(1024, 1024), + coarse_segmentation_model=base_model, + use_bbox=config["use_bbox"], + use_points=config["use_points"], + use_mask=config["use_mask"], + debug=False, + num_points_for_sam=1, + use_cca=config["do_cca"], + point_mode=config["point_mode"], + use_sam_trans=True, + coarse_pred_only=config["coarse_pred_only"], + sam_pretrained_path=sam_checkpoint, + use_neg_points=config["use_neg_points"], + ) + model = model.to(torch.device("cuda")) + model.eval() + return model + +# Function to preprocess images +def preprocess_image(image, transform): + if isinstance(image, np.ndarray): + image_np = image + else: + # Convert PIL Image to numpy array + image_np = np.array(image) + + # Convert to RGB if grayscale + if len(image_np.shape) == 2: + image_np = np.stack([image_np] * 3, axis=2) + elif image_np.shape[2] == 1: + image_np = np.concatenate([image_np] * 3, axis=2) + + # Apply transforms + image_tensor = transform(image_np).unsqueeze(0) + return image_tensor + +# Function to create overlay visualization +def create_overlay(query_image, prediction, colormap='YlOrRd'): + """ + Create an overlay of the prediction on the query image + """ + # Convert tensors to numpy arrays for visualization + if isinstance(query_image, torch.Tensor): + query_image = query_image.cpu().squeeze().numpy() + + if isinstance(prediction, torch.Tensor): + prediction = prediction.cpu().squeeze().numpy() + + # Normalize image for visualization + query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8) + + # Ensure binary mask + prediction = (prediction > 0).astype(np.float32) + + # Create mask overlay + mask_cmap = plt.cm.get_cmap(colormap) + pred_rgba = mask_cmap(prediction) + pred_rgba[..., 3] = prediction * 0.7 # Set alpha channel + + # Create matplotlib figure for overlay + fig, ax = plt.subplots(figsize=(10, 10)) + + # Handle grayscale vs RGB images + if len(query_image.shape) == 2: + ax.imshow(query_image, cmap='gray') + else: + if query_image.shape[0] == 3: # Channel-first format + query_image = np.transpose(query_image, (1, 2, 0)) + ax.imshow(query_image) + + ax.imshow(pred_rgba) + ax.axis('off') + plt.tight_layout() + + # Convert to PIL Image + fig.canvas.draw() + img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close(fig) + + return img + +# Model configuration +config = { + "input_size": [224], + "reload_model_path": "path/to/your/model.pth", # Update with your model path + "model": {"encoder": "resnet50", "decoder": "pspnet"}, + "use_bbox": True, + "use_points": True, + "use_mask": True, + "do_cca": True, + "point_mode": "extreme", + "coarse_pred_only": False, + "use_neg_points": False, + "base_model": TYPE_ALPNET +} + +# Function to run inference +def run_inference(query_image, support_image, support_mask, use_bbox, use_points, use_mask, use_cca, coarse_pred_only): + try: + # Update config based on user selections + config["use_bbox"] = use_bbox + config["use_points"] = use_points + config["use_mask"] = use_mask + config["do_cca"] = use_cca + config["coarse_pred_only"] = coarse_pred_only + + # Check if CUDA is available + if not torch.cuda.is_available(): + return None, "CUDA is not available. This demo requires GPU support." + + # Load the model + model = load_model(config) + + # Preprocess images + sam_trans = ResizeLongestSide(1024) + + # Transform for images + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((1024, 1024), antialias=True), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Process query image + query_img_tensor = preprocess_image(query_image, transform) + + # Process support image + support_img_tensor = preprocess_image(support_image, transform) + + # Process support mask (should be binary) + support_mask_np = np.array(support_mask) + support_mask_np = (support_mask_np > 127).astype(np.float32) # Binarize mask + support_mask_tensor = torch.from_numpy(support_mask_np).unsqueeze(0).unsqueeze(0) + support_mask_tensor = torch.nn.functional.interpolate( + support_mask_tensor, size=(1024, 1024), mode='nearest' + ) + + # Prepare model inputs + support_images = [support_img_tensor.cuda()] + support_masks = [support_mask_tensor.cuda()] + + # Create model input + coarse_model_input = InputFactory.create_input( + input_type=config["base_model"], + query_image=query_img_tensor.cuda(), + support_images=support_images, + support_labels=support_masks, + isval=True, + val_wsize=3, + original_sz=query_img_tensor.shape[-2:], + img_sz=query_img_tensor.shape[-2:], + gts=None, + ) + coarse_model_input.to(torch.device("cuda")) + + # Run inference + with torch.no_grad(): + query_pred, scores = model( + query_img_tensor.cuda(), coarse_model_input, degrees_rotate=0 + ) + + # Create overlay visualization + result_image = create_overlay(query_img_tensor, query_pred) + + confidence_score = np.mean(scores) + return result_image, f"Confidence Score: {confidence_score:.4f}" + + except Exception as e: + return None, f"Error during inference: {str(e)}" + +# Define the Gradio interface +def create_interface(): + with gr.Blocks(title="ProtoSAM Segmentation Demo") as demo: + gr.Markdown("# ProtoSAM Segmentation Demo") + gr.Markdown("Upload a query image, support image, and support mask to generate a segmentation prediction.") + + with gr.Row(): + with gr.Column(): + query_image = gr.Image(label="Query Image", type="pil") + support_image = gr.Image(label="Support Image", type="pil") + support_mask = gr.Image(label="Support Mask", type="pil") + + with gr.Column(): + result_image = gr.Image(label="Prediction Result") + result_text = gr.Textbox(label="Result Information") + + with gr.Row(): + with gr.Column(): + use_bbox = gr.Checkbox(label="Use Bounding Box", value=True) + use_points = gr.Checkbox(label="Use Points", value=True) + use_mask = gr.Checkbox(label="Use Mask", value=True) + + with gr.Column(): + use_cca = gr.Checkbox(label="Use CCA", value=True) + coarse_pred_only = gr.Checkbox(label="Coarse Prediction Only", value=False) + run_button = gr.Button("Run Inference") + + run_button.click( + fn=run_inference, + inputs=[ + query_image, + support_image, + support_mask, + use_bbox, + use_points, + use_mask, + use_cca, + coarse_pred_only + ], + outputs=[result_image, result_text] + ) + + return demo + +# Create and launch the interface +if __name__ == "__main__": + demo = create_interface() + demo.launch(share=True) \ No newline at end of file diff --git a/backbone.sh b/backbone.sh new file mode 100644 index 0000000000000000000000000000000000000000..4319e8f737767d046539f3d0fe94ff704264ad68 --- /dev/null +++ b/backbone.sh @@ -0,0 +1,179 @@ +#!/bin/bash +set -e +GPUID1=0 +export CUDA_VISIBLE_DEVICES=$GPUID1 + +MODE=$1 +if [ $MODE != "validation" ] && [ $MODE != "training" ] +then + echo "mode must be either validation or training" + exit 1 +fi + +# get modality as arg +MODALITY=$2 +# make sure modality is either ct or mri +if [ $MODALITY != "ct" ] && [ $MODALITY != "mri" ] +then + echo "modality must be either ct or mri" + exit 1 +fi + +####### Shared configs ###### +PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training +INPUT_SIZE=256 +ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4) +if [ $MODALITY == "ct" ] +then + DATASET='SABS_Superpix' +else + DATASET='CHAOST2_Superpix' +fi + +if [ $INPUT_SIZE -gt 256 ] +then + DATASET=${DATASET}'_672' +fi + +NWORKER=4 +MODEL_NAME='dinov2_l14' +LORA=0 +RELOAD_PATH=( "None" ) +SKIP_SLICES="True" +DO_CCA="True" +TTT="False" +NSTEP=100000 +RESET_AFTER_SLICE="True" +FINETUNE_ON_SUPPORT="False" +USE_SLICE_ADAPTER="False" +ADAPTER_LAYERS=1 +CLAHE=False +ALL_SCALE=( "MIDDLE") # config of pseudolabels + +LABEL_SETS=$3 +EXCLU='[2,3]' + +if [[ $MODALITY == "mri" && $LABEL_SETS -eq 1 ]] +then + echo "exluding 1, 4" + EXCLU='[1,4]' # liver(1), spleen(4) +fi + +ORGANS='kidneys' +if [ $LABEL_SETS -eq 1 ] +then + ORGANS='liver_spleen' +fi + + +FREE_DESC="" +CPT="${MODE}_${MODEL_NAME}_${MODALITY}" +if [ -n "$FREE_DESC" ] +then + CPT="${CPT}_${FREE_DESC}" +fi + +if [[ $TTT == "True" ]] +then + CPT="${CPT}_ttt_nstep_${NSTEP}" + if [ $RESET_AFTER_SLICE == "True" ] + then + CPT="${CPT}_reset_after_slice" + fi +fi + +if [ $USE_SLICE_ADAPTER == "True" ] +then + CPT="${CPT}_w_adapter_${ADAPTER_LAYERS}_layers" +fi + +if [ $LORA -ne 0 ] +then + CPT="${CPT}_lora_${LORA}" +fi + +if [ $CLAHE == "True" ] +then + CPT="${CPT}_w_clahe" +fi + +if [ $DO_CCA = "True" ] +then + CPT="${CPT}_cca" +fi + +CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}" + +if [ ${EXCLU} = "[]" ] +then + CPT="${CPT}_setting1" +else + CPT="${CPT}_setting2" +fi + +CPT="${CPT}_${ORGANS}_fold" + +###### Training configs (irrelavent in testing) ###### +DECAY=0.95 + +MAX_ITER=1000 # defines the size of an epoch +SNAPSHOT_INTERVAL=25000 # interval for saving snapshot +SEED='1234' + +###### Validation configs ###### +SUPP_ID='[6]' # using the additionally loaded scan as support +if [ $MODALITY == "mri" ] +then + SUPP_ID='[4]' +fi + +echo =================================== + +for ((i=0; i<${#ALL_EV[@]}; i++)) +do + EVAL_FOLD=${ALL_EV[i]} + CPT_W_FOLD="${CPT}_${EVAL_FOLD}" + echo $CPT_W_FOLD on GPU $GPUID1 + for SUPERPIX_SCALE in "${ALL_SCALE[@]}" + do + PREFIX="test_vfold${EVAL_FOLD}" + echo $PREFIX + LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}" + + if [ ! -d $LOGDIR ] + then + mkdir -p $LOGDIR + fi + +python3 $MODE.py with \ + "modelname=$MODEL_NAME" \ + 'usealign=True' \ + 'optim_type=sgd' \ + reload_model_path=${RELOAD_PATH[i]} \ + num_workers=$NWORKER \ + scan_per_load=-1 \ + label_sets=$LABEL_SETS \ + 'use_wce=True' \ + exp_prefix=$PREFIX \ + 'clsname=grid_proto' \ + n_steps=$NSTEP \ + exclude_cls_list=$EXCLU \ + eval_fold=$EVAL_FOLD \ + dataset=$DATASET \ + proto_grid_size=$PROTO_GRID \ + max_iters_per_load=$MAX_ITER \ + min_fg_data=1 seed=$SEED \ + save_snapshot_every=$SNAPSHOT_INTERVAL \ + superpix_scale=$SUPERPIX_SCALE \ + lr_step_gamma=$DECAY \ + path.log_dir=$LOGDIR \ + support_idx=$SUPP_ID \ + lora=$LORA \ + do_cca=$DO_CCA \ + ttt=$TTT \ + adapter_layers=$ADAPTER_LAYERS \ + use_slice_adapter=$USE_SLICE_ADAPTER \ + reset_after_slice=$RESET_AFTER_SLICE \ + "input_size=($INPUT_SIZE, $INPUT_SIZE)" + done +done \ No newline at end of file diff --git a/config_ssl_upload.py b/config_ssl_upload.py new file mode 100644 index 0000000000000000000000000000000000000000..fba41dd9acfc415c6256dcd1ef2e5edb6eff4d1a --- /dev/null +++ b/config_ssl_upload.py @@ -0,0 +1,177 @@ +""" +Experiment configuration file +Extended from config file from original PANet Repository +""" +import os +import re +import glob +import itertools + +import sacred +from sacred import Experiment +from sacred.observers import FileStorageObserver +from sacred.utils import apply_backspaces_and_linefeeds + +from platform import node +from datetime import datetime + +from util.consts import IMG_SIZE + +sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False +sacred.SETTINGS.CAPTURE_MODE = 'no' + +ex = Experiment('mySSL') +ex.captured_out_filter = apply_backspaces_and_linefeeds + +source_folders = ['.', './dataloaders', './models', './util'] +sources_to_save = list(itertools.chain.from_iterable( + [glob.glob(f'{folder}/*.py') for folder in source_folders])) +for source_file in sources_to_save: + ex.add_source_file(source_file) + +@ex.config +def cfg(): + """Default configurations""" + seed = 1234 + gpu_id = 0 + mode = 'train' # for now only allows 'train' + do_validation=False + num_workers = 4 # 0 for debugging. + + dataset = 'CHAOST2' # i.e. abdominal MRI + use_coco_init = True # initialize backbone with MS_COCO initialization. Anyway coco does not contain medical images + + ### Training + n_steps = 100100 + batch_size = 1 + lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)] + lr_step_gamma = 0.95 + ignore_label = 255 + print_interval = 100 + save_snapshot_every = 25000 + max_iters_per_load = 1000 # epoch size, interval for reloading the dataset + epochs=1 + scan_per_load = -1 # numbers of 3d scans per load for saving memory. If -1, load the entire dataset to the memory + which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms + input_size = (IMG_SIZE, IMG_SIZE) + min_fg_data='100' # when training with manual annotations, indicating number of foreground pixels in a single class single slice. This empirically stablizes the training process + label_sets = 0 # which group of labels taking as training (the rest are for testing) + curr_cls = "" # choose between rk, lk, spleen and liver + exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1 + usealign = True # see vanilla PANet + use_wce = True + use_dinov2_loss = False + dice_loss = False + ### Validation + z_margin = 0 + eval_fold = 0 # which fold for 5 fold cross validation + support_idx=[-1] # indicating which scan is used as support in testing. + val_wsize=2 # L_H, L_W in testing + n_sup_part = 3 # number of chuncks in testing + use_clahe = False + use_slice_adapter = False + adapter_layers=3 + debug=True + skip_no_organ_slices=True + # Network + modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab + clsname = None # + reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization) + proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training + feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size, should couple this with backbone in future + lora = 0 + use_3_slices=False + do_cca=False + use_edge_detector=False + finetune_on_support=False + sliding_window_confidence_segmentation=False + finetune_model_on_single_slice=False + online_finetuning=True + + use_bbox=True # for SAM + use_points=True # for SAM + use_mask=False # for SAM + base_model="alpnet" # or "SAM" + # SSL + superpix_scale = 'MIDDLE' #MIDDLE/ LARGE + use_pos_enc=False + support_txt_file = None # path to a txt file containing support slices + augment_support_set=False + coarse_pred_only=False # for ProtoSAM + point_mode="both" # for ProtoSAM, choose: both, conf, centroid + use_neg_points=False + n_support=1 # num support images + protosam_sam_ver="sam_h" # or medsam + grad_accumulation_steps=1 + ttt=False + reset_after_slice=True # for TTT, if to reset the model after finetuning on each slice + model = { + 'align': usealign, + 'dinov2_loss': use_dinov2_loss, + 'use_coco_init': use_coco_init, + 'which_model': modelname, + 'cls_name': clsname, + 'proto_grid_size' : proto_grid_size, + 'feature_hw': feature_hw, + 'reload_model_path': reload_model_path, + 'lora': lora, + 'use_slice_adapter': use_slice_adapter, + 'adapter_layers': adapter_layers, + 'debug': debug, + 'use_pos_enc': use_pos_enc + } + + task = { + 'n_ways': 1, + 'n_shots': 1, + 'n_queries': 1, + 'npart': n_sup_part + } + + optim_type = 'sgd' + lr=1e-3 + momentum=0.9 + weight_decay=0.0005 + optim = { + 'lr': lr, + 'momentum': momentum, + 'weight_decay': weight_decay + } + + exp_prefix = '' + + exp_str = '_'.join( + [exp_prefix] + + [dataset,] + + [f'sets_{label_sets}_{task["n_shots"]}shot']) + + path = { + 'log_dir': './runs', + 'SABS':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized" + }, + 'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448" + }, + 'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672" + }, + 'C0':{'data_dir': "feed your dataset path here" + }, + 'CHAOST2':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized" + }, + 'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/" + }, + 'SABS_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"}, + 'C0_Superpix':{'data_dir': "feed your dataset path here"}, + 'CHAOST2_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"}, + 'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"}, + 'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"}, + 'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"}, + } + + +@ex.config_hook +def add_observer(config, command_name, logger): + """A hook fucntion to add observer""" + exp_name = f'{ex.path}_{config["exp_str"]}' + observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name)) + ex.observers.append(observer) + return config diff --git a/data/data_processing.ipynb b/data/data_processing.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..76cdd5bea9eecf70b9d10838a9a19b8bb571de59 --- /dev/null +++ b/data/data_processing.ipynb @@ -0,0 +1,2687 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For convinience, I've unified all of the data preprocessing\n", + "notebooks from [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation.git) into a single notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reset\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "import numpy as np\n", + "import os\n", + "import glob\n", + "import SimpleITK as sitk\n", + "import sys\n", + "\n", + "sys.path.insert(0, '../')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Create dirs for the SABS and CHAOS datasets\n", + "os.makedirs('./SABS', exist_ok=True)\n", + "os.makedirs('./CHAOST2', exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def copy_spacing_ori(src, dst):\n", + " dst.SetSpacing(src.GetSpacing())\n", + " dst.SetOrigin(src.GetOrigin())\n", + " dst.SetDirection(src.GetDirection())\n", + " return dst\n", + "\n", + "# helper functions copy pasted\n", + "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n", + " resample = sitk.ResampleImageFilter()\n", + " resample.SetInterpolator(interpolator)\n", + " resample.SetOutputDirection(mov_img_obj.GetDirection())\n", + " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n", + " mov_spacing = mov_img_obj.GetSpacing()\n", + "\n", + " resample.SetOutputSpacing(new_spacing)\n", + " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n", + " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n", + "\n", + " resample.SetSize( [int(sz+1) for sz in new_size] )\n", + " if logging:\n", + " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n", + " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n", + "\n", + " return resample.Execute(mov_img_obj)\n", + "\n", + "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n", + " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n", + " lbvs = np.unique(src_mat)\n", + " if logging:\n", + " print(\"Label values: {}\".format(lbvs))\n", + " for idx, lbv in enumerate(lbvs):\n", + " _src_curr_mat = np.float32(src_mat == lbv) \n", + " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n", + " _src_curr_obj.CopyInformation(mov_lb_obj)\n", + " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n", + " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n", + " if idx == 0:\n", + " out_vol = _tar_curr_mat\n", + " else:\n", + " out_vol[_tar_curr_mat == lbv] = lbv\n", + " out_obj = sitk.GetImageFromArray(out_vol)\n", + " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n", + " if ref_img != None:\n", + " out_obj.CopyInformation(ref_img)\n", + " return out_obj\n", + " \n", + "## Then crop ROI\n", + "def get_label_center(label):\n", + " nnz = np.sum(label > 1e-5)\n", + " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n", + "\n", + "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n", + " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n", + " Args:\n", + " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n", + " only_2d: only do cropping on first two dimensions\n", + " \"\"\"\n", + " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n", + " if only_2d:\n", + " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n", + " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n", + " _expand_cropsize.append(ori_vol.shape[-1])\n", + " \n", + " image_patch = np.ones(tuple(_expand_cropsize)) * padval\n", + "\n", + " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n", + " _min_idx = [0,0,0]\n", + " _max_idx = list(ori_vol.shape)\n", + "\n", + " # bias of actual cropped size to the beginning and the end of this volume\n", + " _bias_start = [0,0,0]\n", + " _bias_end = [0,0,0]\n", + "\n", + " for dim,hsize in enumerate(half_size):\n", + " if dim == 2 and only_2d:\n", + " break\n", + "\n", + " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n", + " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n", + "\n", + " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n", + " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n", + " \n", + " if only_2d:\n", + " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", + " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n", + " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", + " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n", + "\n", + " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n", + " # then goes back to original volume\n", + " else:\n", + " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", + " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n", + " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n", + " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", + " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n", + " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n", + "\n", + " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n", + " return image_patch\n", + "\n", + "s2n = sitk.GetArrayFromImage\n", + "\n", + "\n", + "def resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC, required_res=512):\n", + " spa_fac = SPA_FAC\n", + " for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", + "\n", + " # lb_n = nio.read_nii_bysitk(seg_fid)\n", + "\n", + " img_obj = sitk.ReadImage( img_fid )\n", + " seg_obj = sitk.ReadImage( seg_fid )\n", + " print(img_fid, seg_fid)\n", + " ## image\n", + " array = sitk.GetArrayFromImage(img_obj)\n", + " H = W = array.shape[-1]\n", + " if SPA_FAC is None:\n", + " spa_fac = (H - 2 * BD_BIAS) / required_res\n", + " print(array.shape, f\"label shape {sitk.GetArrayFromImage(seg_obj).shape}\")\n", + " # cropping\n", + " array = array[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]\n", + " cropped_img_o = sitk.GetImageFromArray(array)\n", + " cropped_img_o = copy_spacing_ori(img_obj, cropped_img_o)\n", + "\n", + " # resampling\n", + " img_spa_ori = img_obj.GetSpacing()\n", + " res_img_o = resample_by_res(cropped_img_o, [img_spa_ori[0] * spa_fac, img_spa_ori[1] * spa_fac, img_spa_ori[-1]], interpolator = sitk.sitkLinear,\n", + " logging = True)\n", + "\n", + " ## label\n", + " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", + " # cropping\n", + " lb_arr = lb_arr[:,BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]\n", + " cropped_lb_o = sitk.GetImageFromArray(lb_arr)\n", + " cropped_lb_o = copy_spacing_ori(seg_obj, cropped_lb_o)\n", + "\n", + " lb_spa_ori = seg_obj.GetSpacing()\n", + "\n", + " # resampling\n", + " res_lb_o = resample_lb_by_res(cropped_lb_o, [lb_spa_ori[0] * spa_fac, lb_spa_ori[1] * spa_fac, lb_spa_ori[-1] ], interpolator = sitk.sitkLinear,\n", + " ref_img = res_img_o, logging = True)\n", + "\n", + " \n", + " out_img_fid = os.path.join( scan_dir, f'image_{pid}.nii.gz' )\n", + " out_lb_fid = os.path.join( scan_dir, f'label_{pid}.nii.gz' ) \n", + " \n", + " # then save\n", + " sitk.WriteImage(res_img_o, out_img_fid, True) \n", + " sitk.WriteImage(res_lb_o, out_lb_fid, True) \n", + " print(f\"{out_img_fid} has been saved, shape: {res_img_o.GetSize()}\")\n", + " print(f\"{out_lb_fid} has been saved\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Intensitiy Normalization for CT Images" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# set up directories for images\n", + "IMG_FOLDER=\"./miccai2015/RawData/Training/img\"\n", + "SEG_FOLDER=\"./miccai2015/RawData/Training/label\"\n", + "OUT_FOLDER=\"./SABS/tmp_normalized/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + } + ], + "source": [ + "imgs = sorted(glob.glob(IMG_FOLDER + \"/*.nii.gz\"))\n", + "segs = sorted(glob.glob(SEG_FOLDER + \"/*.nii.gz\"))\n", + "pids = [pid.split(\"img\")[-1].split(\".\")[0] for pid in imgs]\n", + "print(sorted(pids))\n", + "assert len(imgs) == len(segs)\n", + "for img, seg in zip(imgs, segs):\n", + " print(img, seg)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(147, 512, 512) label shape (147, 512, 512)\n", + "./SABS/tmp_normalized/image_0.nii.gz has been save\n", + "./SABS/tmp_normalized/label_0.nii.gz has been save\n", + "(139, 512, 512) label shape (139, 512, 512)\n", + "./SABS/tmp_normalized/image_1.nii.gz has been save\n", + "./SABS/tmp_normalized/label_1.nii.gz has been save\n", + "(198, 512, 512) label shape (198, 512, 512)\n", + "./SABS/tmp_normalized/image_2.nii.gz has been save\n", + "./SABS/tmp_normalized/label_2.nii.gz has been save\n", + "(140, 512, 512) label shape (140, 512, 512)\n", + "./SABS/tmp_normalized/image_3.nii.gz has been save\n", + "./SABS/tmp_normalized/label_3.nii.gz has been save\n", + "(117, 512, 512) label shape (117, 512, 512)\n", + "./SABS/tmp_normalized/image_4.nii.gz has been save\n", + "./SABS/tmp_normalized/label_4.nii.gz has been save\n", + "(131, 512, 512) label shape (131, 512, 512)\n", + "./SABS/tmp_normalized/image_5.nii.gz has been save\n", + "./SABS/tmp_normalized/label_5.nii.gz has been save\n", + "(163, 512, 512) label shape (163, 512, 512)\n", + "./SABS/tmp_normalized/image_6.nii.gz has been save\n", + "./SABS/tmp_normalized/label_6.nii.gz has been save\n", + "(148, 512, 512) label shape (148, 512, 512)\n", + "./SABS/tmp_normalized/image_7.nii.gz has been save\n", + "./SABS/tmp_normalized/label_7.nii.gz has been save\n", + "(149, 512, 512) label shape (149, 512, 512)\n", + "./SABS/tmp_normalized/image_8.nii.gz has been save\n", + "./SABS/tmp_normalized/label_8.nii.gz has been save\n", + "(148, 512, 512) label shape (148, 512, 512)\n", + "./SABS/tmp_normalized/image_9.nii.gz has been save\n", + "./SABS/tmp_normalized/label_9.nii.gz has been save\n", + "(143, 512, 512) label shape (143, 512, 512)\n", + "./SABS/tmp_normalized/image_10.nii.gz has been save\n", + "./SABS/tmp_normalized/label_10.nii.gz has been save\n", + "(89, 512, 512) label shape (89, 512, 512)\n", + "./SABS/tmp_normalized/image_11.nii.gz has been save\n", + "./SABS/tmp_normalized/label_11.nii.gz has been save\n", + "(96, 512, 512) label shape (96, 512, 512)\n", + "./SABS/tmp_normalized/image_12.nii.gz has been save\n", + "./SABS/tmp_normalized/label_12.nii.gz has been save\n", + "(124, 512, 512) label shape (124, 512, 512)\n", + "./SABS/tmp_normalized/image_13.nii.gz has been save\n", + "./SABS/tmp_normalized/label_13.nii.gz has been save\n", + "(85, 512, 512) label shape (85, 512, 512)\n", + "./SABS/tmp_normalized/image_14.nii.gz has been save\n", + "./SABS/tmp_normalized/label_14.nii.gz has been save\n", + "(131, 512, 512) label shape (131, 512, 512)\n", + "./SABS/tmp_normalized/image_15.nii.gz has been save\n", + "./SABS/tmp_normalized/label_15.nii.gz has been save\n", + "(88, 512, 512) label shape (88, 512, 512)\n", + "./SABS/tmp_normalized/image_16.nii.gz has been save\n", + "./SABS/tmp_normalized/label_16.nii.gz has been save\n", + "(89, 512, 512) label shape (89, 512, 512)\n", + "./SABS/tmp_normalized/image_17.nii.gz has been save\n", + "./SABS/tmp_normalized/label_17.nii.gz has been save\n", + "(100, 512, 512) label shape (100, 512, 512)\n", + "./SABS/tmp_normalized/image_18.nii.gz has been save\n", + "./SABS/tmp_normalized/label_18.nii.gz has been save\n", + "(153, 512, 512) label shape (153, 512, 512)\n", + "./SABS/tmp_normalized/image_19.nii.gz has been save\n", + "./SABS/tmp_normalized/label_19.nii.gz has been save\n", + "(93, 512, 512) label shape (93, 512, 512)\n", + "./SABS/tmp_normalized/image_20.nii.gz has been save\n", + "./SABS/tmp_normalized/label_20.nii.gz has been save\n", + "(144, 512, 512) label shape (144, 512, 512)\n", + "./SABS/tmp_normalized/image_21.nii.gz has been save\n", + "./SABS/tmp_normalized/label_21.nii.gz has been save\n", + "(104, 512, 512) label shape (104, 512, 512)\n", + "./SABS/tmp_normalized/image_22.nii.gz has been save\n", + "./SABS/tmp_normalized/label_22.nii.gz has been save\n", + "(98, 512, 512) label shape (98, 512, 512)\n", + "./SABS/tmp_normalized/image_23.nii.gz has been save\n", + "./SABS/tmp_normalized/label_23.nii.gz has been save\n", + "(94, 512, 512) label shape (94, 512, 512)\n", + "./SABS/tmp_normalized/image_24.nii.gz has been save\n", + "./SABS/tmp_normalized/label_24.nii.gz has been save\n", + "(184, 512, 512) label shape (184, 512, 512)\n", + "./SABS/tmp_normalized/image_25.nii.gz has been save\n", + "./SABS/tmp_normalized/label_25.nii.gz has been save\n", + "(99, 512, 512) label shape (99, 512, 512)\n", + "./SABS/tmp_normalized/image_26.nii.gz has been save\n", + "./SABS/tmp_normalized/label_26.nii.gz has been save\n", + "(100, 512, 512) label shape (100, 512, 512)\n", + "./SABS/tmp_normalized/image_27.nii.gz has been save\n", + "./SABS/tmp_normalized/label_27.nii.gz has been save\n", + "(90, 512, 512) label shape (90, 512, 512)\n", + "./SABS/tmp_normalized/image_28.nii.gz has been save\n", + "./SABS/tmp_normalized/label_28.nii.gz has been save\n", + "(195, 512, 512) label shape (195, 512, 512)\n", + "./SABS/tmp_normalized/image_29.nii.gz has been save\n", + "./SABS/tmp_normalized/label_29.nii.gz has been save\n" + ] + } + ], + "source": [ + "import copy\n", + "scan_dir = OUT_FOLDER\n", + "LIR = -125\n", + "HIR = 275\n", + "os.makedirs(scan_dir, exist_ok = True)\n", + "\n", + "reindex = 0\n", + "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", + "\n", + " img_obj = sitk.ReadImage( img_fid )\n", + " seg_obj = sitk.ReadImage( seg_fid )\n", + "\n", + " array = sitk.GetArrayFromImage(img_obj)\n", + " print(array.shape, f\"label shape {sitk.GetArrayFromImage(seg_obj).shape}\")\n", + " array[array > HIR] = HIR\n", + " array[array < LIR] = LIR\n", + " \n", + " array = (array - array.min()) / (array.max() - array.min()) * 255.0\n", + " \n", + " # then normalize this\n", + " \n", + " wined_img = sitk.GetImageFromArray(array)\n", + " wined_img = copy_spacing_ori(img_obj, wined_img)\n", + " \n", + " out_img_fid = os.path.join( scan_dir, f'image_{str(reindex)}.nii.gz' )\n", + " out_lb_fid = os.path.join( scan_dir, f'label_{str(reindex)}.nii.gz' ) \n", + " \n", + " # then save\n", + " sitk.WriteImage(wined_img, out_img_fid, True) \n", + " sitk.WriteImage(seg_obj, out_lb_fid, True) \n", + " print(\"{} has been save\".format(out_img_fid))\n", + " print(\"{} has been save\".format(out_lb_fid))\n", + " reindex += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Overview\n", + "\n", + "This is the second step of preprocessing\n", + "\n", + "Cut out irrelevant empty boundary and resample to 512x512 in axial plane.\n", + "\n", + "Input: intensity-normalized images\n", + "\n", + "Output: spacially resampled images" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '4', '5', '6', '7', '8', '9']\n", + "./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz\n", + "./SABS/tmp_normalized/image_1.nii.gz ./SABS/tmp_normalized/label_1.nii.gz\n", + "./SABS/tmp_normalized/image_10.nii.gz ./SABS/tmp_normalized/label_10.nii.gz\n", + "./SABS/tmp_normalized/image_11.nii.gz ./SABS/tmp_normalized/label_11.nii.gz\n", + "./SABS/tmp_normalized/image_12.nii.gz ./SABS/tmp_normalized/label_12.nii.gz\n", + "./SABS/tmp_normalized/image_13.nii.gz ./SABS/tmp_normalized/label_13.nii.gz\n", + "./SABS/tmp_normalized/image_14.nii.gz ./SABS/tmp_normalized/label_14.nii.gz\n", + "./SABS/tmp_normalized/image_15.nii.gz ./SABS/tmp_normalized/label_15.nii.gz\n", + "./SABS/tmp_normalized/image_16.nii.gz ./SABS/tmp_normalized/label_16.nii.gz\n", + "./SABS/tmp_normalized/image_17.nii.gz ./SABS/tmp_normalized/label_17.nii.gz\n", + "./SABS/tmp_normalized/image_18.nii.gz ./SABS/tmp_normalized/label_18.nii.gz\n", + "./SABS/tmp_normalized/image_19.nii.gz ./SABS/tmp_normalized/label_19.nii.gz\n", + "./SABS/tmp_normalized/image_2.nii.gz ./SABS/tmp_normalized/label_2.nii.gz\n", + "./SABS/tmp_normalized/image_20.nii.gz ./SABS/tmp_normalized/label_20.nii.gz\n", + "./SABS/tmp_normalized/image_21.nii.gz ./SABS/tmp_normalized/label_21.nii.gz\n", + "./SABS/tmp_normalized/image_22.nii.gz ./SABS/tmp_normalized/label_22.nii.gz\n", + "./SABS/tmp_normalized/image_23.nii.gz ./SABS/tmp_normalized/label_23.nii.gz\n", + "./SABS/tmp_normalized/image_24.nii.gz ./SABS/tmp_normalized/label_24.nii.gz\n", + "./SABS/tmp_normalized/image_25.nii.gz ./SABS/tmp_normalized/label_25.nii.gz\n", + "./SABS/tmp_normalized/image_26.nii.gz ./SABS/tmp_normalized/label_26.nii.gz\n", + "./SABS/tmp_normalized/image_27.nii.gz ./SABS/tmp_normalized/label_27.nii.gz\n", + "./SABS/tmp_normalized/image_28.nii.gz ./SABS/tmp_normalized/label_28.nii.gz\n", + "./SABS/tmp_normalized/image_29.nii.gz ./SABS/tmp_normalized/label_29.nii.gz\n", + "./SABS/tmp_normalized/image_3.nii.gz ./SABS/tmp_normalized/label_3.nii.gz\n", + "./SABS/tmp_normalized/image_4.nii.gz ./SABS/tmp_normalized/label_4.nii.gz\n", + "./SABS/tmp_normalized/image_5.nii.gz ./SABS/tmp_normalized/label_5.nii.gz\n", + "./SABS/tmp_normalized/image_6.nii.gz ./SABS/tmp_normalized/label_6.nii.gz\n", + "./SABS/tmp_normalized/image_7.nii.gz ./SABS/tmp_normalized/label_7.nii.gz\n", + "./SABS/tmp_normalized/image_8.nii.gz ./SABS/tmp_normalized/label_8.nii.gz\n", + "./SABS/tmp_normalized/image_9.nii.gz ./SABS/tmp_normalized/label_9.nii.gz\n" + ] + } + ], + "source": [ + "IMG_FOLDER = \"./SABS/tmp_normalized\"\n", + "\n", + "SEG_FOLDER = IMG_FOLDER\n", + "imgs = glob.glob(IMG_FOLDER + \"/image_*.nii.gz\")\n", + "imgs = sorted([ fid for fid in sorted(imgs) ])\n", + "segs = sorted([ fid for fid in glob.glob(SEG_FOLDER + \"/label_*.nii.gz\")])\n", + "\n", + "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]\n", + "print(pids)\n", + "for img, seg in zip(imgs, segs):\n", + " print(img, seg)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz\n", + "(147, 512, 512) label shape (147, 512, 512)\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]\n", + "Size (448, 448, 147) -> [448. 448. 147.]\n", + "./SABS/sabs_CT_normalized/image_0.nii.gz has been saved, shape: (449, 449, 148)\n", + "./SABS/sabs_CT_normalized/label_0.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_1.nii.gz ./SABS/tmp_normalized/label_1.nii.gz\n", + "(139, 512, 512) label shape (139, 512, 512)\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "Spacing: (0.720703125, 0.720703125, 3.0) -> [0.720703125, 0.720703125, 3.0]\n", + "Size (448, 448, 139) -> [448. 448. 139.]\n", + "./SABS/sabs_CT_normalized/image_1.nii.gz has been saved, shape: (449, 449, 140)\n", + "./SABS/sabs_CT_normalized/label_1.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_10.nii.gz ./SABS/tmp_normalized/label_10.nii.gz\n", + "(143, 512, 512) label shape (143, 512, 512)\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "Spacing: (0.68359375, 0.68359375, 3.0) -> [0.68359375, 0.68359375, 3.0]\n", + "Size (448, 448, 143) -> [448. 448. 143.]\n", + "./SABS/sabs_CT_normalized/image_10.nii.gz has been saved, shape: (449, 449, 144)\n", + "./SABS/sabs_CT_normalized/label_10.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_11.nii.gz ./SABS/tmp_normalized/label_11.nii.gz\n", + "(89, 512, 512) label shape (89, 512, 512)\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.7675780057907104, 0.7675780057907104, 5.0) -> [0.7675780057907104, 0.7675780057907104, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "./SABS/sabs_CT_normalized/image_11.nii.gz has been saved, shape: (449, 449, 90)\n", + "./SABS/sabs_CT_normalized/label_11.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_12.nii.gz ./SABS/tmp_normalized/label_12.nii.gz\n", + "(96, 512, 512) label shape (96, 512, 512)\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Label values: [ 0 1 2 3 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "Spacing: (0.70703125, 0.70703125, 5.0) -> [0.70703125, 0.70703125, 5.0]\n", + "Size (448, 448, 96) -> [448. 448. 96.]\n", + "./SABS/sabs_CT_normalized/image_12.nii.gz has been saved, shape: (449, 449, 97)\n", + "./SABS/sabs_CT_normalized/label_12.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_13.nii.gz ./SABS/tmp_normalized/label_13.nii.gz\n", + "(124, 512, 512) label shape (124, 512, 512)\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "Spacing: (0.685546875, 0.685546875, 3.0) -> [0.685546875, 0.685546875, 3.0]\n", + "Size (448, 448, 124) -> [448. 448. 124.]\n", + "./SABS/sabs_CT_normalized/image_13.nii.gz has been saved, shape: (449, 449, 125)\n", + "./SABS/sabs_CT_normalized/label_13.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_14.nii.gz ./SABS/tmp_normalized/label_14.nii.gz\n", + "(85, 512, 512) label shape (85, 512, 512)\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "Spacing: (0.83203125, 0.83203125, 5.0) -> [0.83203125, 0.83203125, 5.0]\n", + "Size (448, 448, 85) -> [448. 448. 85.]\n", + "./SABS/sabs_CT_normalized/image_14.nii.gz has been saved, shape: (449, 449, 86)\n", + "./SABS/sabs_CT_normalized/label_14.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_15.nii.gz ./SABS/tmp_normalized/label_15.nii.gz\n", + "(131, 512, 512) label shape (131, 512, 512)\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.7792969942092896, 0.7792969942092896, 5.0) -> [0.7792969942092896, 0.7792969942092896, 5.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "./SABS/sabs_CT_normalized/image_15.nii.gz has been saved, shape: (449, 449, 132)\n", + "./SABS/sabs_CT_normalized/label_15.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_16.nii.gz ./SABS/tmp_normalized/label_16.nii.gz\n", + "(88, 512, 512) label shape (88, 512, 512)\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "Spacing: (0.7753909826278687, 0.7753909826278687, 5.0) -> [0.7753909826278687, 0.7753909826278687, 5.0]\n", + "Size (448, 448, 88) -> [448. 448. 88.]\n", + "./SABS/sabs_CT_normalized/image_16.nii.gz has been saved, shape: (449, 449, 89)\n", + "./SABS/sabs_CT_normalized/label_16.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_17.nii.gz ./SABS/tmp_normalized/label_17.nii.gz\n", + "(89, 512, 512) label shape (89, 512, 512)\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "Spacing: (0.796875, 0.796875, 5.0) -> [0.796875, 0.796875, 5.0]\n", + "Size (448, 448, 89) -> [448. 448. 89.]\n", + "./SABS/sabs_CT_normalized/image_17.nii.gz has been saved, shape: (449, 449, 90)\n", + "./SABS/sabs_CT_normalized/label_17.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_18.nii.gz ./SABS/tmp_normalized/label_18.nii.gz\n", + "(100, 512, 512) label shape (100, 512, 512)\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.875, 0.875, 3.0) -> [0.875, 0.875, 3.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "./SABS/sabs_CT_normalized/image_18.nii.gz has been saved, shape: (449, 449, 101)\n", + "./SABS/sabs_CT_normalized/label_18.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_19.nii.gz ./SABS/tmp_normalized/label_19.nii.gz\n", + "(153, 512, 512) label shape (153, 512, 512)\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 153) -> [448. 448. 153.]\n", + "./SABS/sabs_CT_normalized/image_19.nii.gz has been saved, shape: (449, 449, 154)\n", + "./SABS/sabs_CT_normalized/label_19.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_2.nii.gz ./SABS/tmp_normalized/label_2.nii.gz\n", + "(198, 512, 512) label shape (198, 512, 512)\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "Spacing: (0.8984375, 0.8984375, 3.0) -> [0.8984375, 0.8984375, 3.0]\n", + "Size (448, 448, 198) -> [448. 448. 198.]\n", + "./SABS/sabs_CT_normalized/image_2.nii.gz has been saved, shape: (449, 449, 199)\n", + "./SABS/sabs_CT_normalized/label_2.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_20.nii.gz ./SABS/tmp_normalized/label_20.nii.gz\n", + "(93, 512, 512) label shape (93, 512, 512)\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "Spacing: (0.837890625, 0.837890625, 3.0) -> [0.837890625, 0.837890625, 3.0]\n", + "Size (448, 448, 93) -> [448. 448. 93.]\n", + "./SABS/sabs_CT_normalized/image_20.nii.gz has been saved, shape: (449, 449, 94)\n", + "./SABS/sabs_CT_normalized/label_20.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_21.nii.gz ./SABS/tmp_normalized/label_21.nii.gz\n", + "(144, 512, 512) label shape (144, 512, 512)\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "Spacing: (0.740234375, 0.740234375, 3.0) -> [0.740234375, 0.740234375, 3.0]\n", + "Size (448, 448, 144) -> [448. 448. 144.]\n", + "./SABS/sabs_CT_normalized/image_21.nii.gz has been saved, shape: (449, 449, 145)\n", + "./SABS/sabs_CT_normalized/label_21.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_22.nii.gz ./SABS/tmp_normalized/label_22.nii.gz\n", + "(104, 512, 512) label shape (104, 512, 512)\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "Spacing: (0.8144530057907104, 0.8144530057907104, 5.0) -> [0.8144530057907104, 0.8144530057907104, 5.0]\n", + "Size (448, 448, 104) -> [448. 448. 104.]\n", + "./SABS/sabs_CT_normalized/image_22.nii.gz has been saved, shape: (449, 449, 105)\n", + "./SABS/sabs_CT_normalized/label_22.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_23.nii.gz ./SABS/tmp_normalized/label_23.nii.gz\n", + "(98, 512, 512) label shape (98, 512, 512)\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "Spacing: (0.671875, 0.671875, 5.0) -> [0.671875, 0.671875, 5.0]\n", + "Size (448, 448, 98) -> [448. 448. 98.]\n", + "./SABS/sabs_CT_normalized/image_23.nii.gz has been saved, shape: (449, 449, 99)\n", + "./SABS/sabs_CT_normalized/label_23.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_24.nii.gz ./SABS/tmp_normalized/label_24.nii.gz\n", + "(94, 512, 512) label shape (94, 512, 512)\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Label values: [ 0 1 2 3 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "Spacing: (0.9199219942092896, 0.9199219942092896, 5.0) -> [0.9199219942092896, 0.9199219942092896, 5.0]\n", + "Size (448, 448, 94) -> [448. 448. 94.]\n", + "./SABS/sabs_CT_normalized/image_24.nii.gz has been saved, shape: (449, 449, 95)\n", + "./SABS/sabs_CT_normalized/label_24.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_25.nii.gz ./SABS/tmp_normalized/label_25.nii.gz\n", + "(184, 512, 512) label shape (184, 512, 512)\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "Spacing: (0.74609375, 0.74609375, 3.0) -> [0.74609375, 0.74609375, 3.0]\n", + "Size (448, 448, 184) -> [448. 448. 184.]\n", + "./SABS/sabs_CT_normalized/image_25.nii.gz has been saved, shape: (449, 449, 185)\n", + "./SABS/sabs_CT_normalized/label_25.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_26.nii.gz ./SABS/tmp_normalized/label_26.nii.gz\n", + "(99, 512, 512) label shape (99, 512, 512)\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "Spacing: (0.703125, 0.703125, 5.0) -> [0.703125, 0.703125, 5.0]\n", + "Size (448, 448, 99) -> [448. 448. 99.]\n", + "./SABS/sabs_CT_normalized/image_26.nii.gz has been saved, shape: (449, 449, 100)\n", + "./SABS/sabs_CT_normalized/label_26.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_27.nii.gz ./SABS/tmp_normalized/label_27.nii.gz\n", + "(100, 512, 512) label shape (100, 512, 512)\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 5.0) -> [0.9765620231628418, 0.9765620231628418, 5.0]\n", + "Size (448, 448, 100) -> [448. 448. 100.]\n", + "./SABS/sabs_CT_normalized/image_27.nii.gz has been saved, shape: (449, 449, 101)\n", + "./SABS/sabs_CT_normalized/label_27.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_28.nii.gz ./SABS/tmp_normalized/label_28.nii.gz\n", + "(90, 512, 512) label shape (90, 512, 512)\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "Spacing: (0.9760000109672546, 0.9760000109672546, 5.0) -> [0.9760000109672546, 0.9760000109672546, 5.0]\n", + "Size (448, 448, 90) -> [448. 448. 90.]\n", + "./SABS/sabs_CT_normalized/image_28.nii.gz has been saved, shape: (449, 449, 91)\n", + "./SABS/sabs_CT_normalized/label_28.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_29.nii.gz ./SABS/tmp_normalized/label_29.nii.gz\n", + "(195, 512, 512) label shape (195, 512, 512)\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "Spacing: (0.7421875, 0.7421875, 3.0) -> [0.7421875, 0.7421875, 3.0]\n", + "Size (448, 448, 195) -> [448. 448. 195.]\n", + "./SABS/sabs_CT_normalized/image_29.nii.gz has been saved, shape: (449, 449, 196)\n", + "./SABS/sabs_CT_normalized/label_29.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_3.nii.gz ./SABS/tmp_normalized/label_3.nii.gz\n", + "(140, 512, 512) label shape (140, 512, 512)\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "Spacing: (0.59375, 0.59375, 3.0) -> [0.59375, 0.59375, 3.0]\n", + "Size (448, 448, 140) -> [448. 448. 140.]\n", + "./SABS/sabs_CT_normalized/image_3.nii.gz has been saved, shape: (449, 449, 141)\n", + "./SABS/sabs_CT_normalized/label_3.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_4.nii.gz ./SABS/tmp_normalized/label_4.nii.gz\n", + "(117, 512, 512) label shape (117, 512, 512)\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "Spacing: (0.90625, 0.90625, 3.0) -> [0.90625, 0.90625, 3.0]\n", + "Size (448, 448, 117) -> [448. 448. 117.]\n", + "./SABS/sabs_CT_normalized/image_4.nii.gz has been saved, shape: (449, 449, 118)\n", + "./SABS/sabs_CT_normalized/label_4.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_5.nii.gz ./SABS/tmp_normalized/label_5.nii.gz\n", + "(131, 512, 512) label shape (131, 512, 512)\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "Spacing: (0.701171875, 0.701171875, 3.0) -> [0.701171875, 0.701171875, 3.0]\n", + "Size (448, 448, 131) -> [448. 448. 131.]\n", + "./SABS/sabs_CT_normalized/image_5.nii.gz has been saved, shape: (449, 449, 132)\n", + "./SABS/sabs_CT_normalized/label_5.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_6.nii.gz ./SABS/tmp_normalized/label_6.nii.gz\n", + "(163, 512, 512) label shape (163, 512, 512)\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "Spacing: (0.748046875, 0.748046875, 3.0) -> [0.748046875, 0.748046875, 3.0]\n", + "Size (448, 448, 163) -> [448. 448. 163.]\n", + "./SABS/sabs_CT_normalized/image_6.nii.gz has been saved, shape: (449, 449, 164)\n", + "./SABS/sabs_CT_normalized/label_6.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_7.nii.gz ./SABS/tmp_normalized/label_7.nii.gz\n", + "(148, 512, 512) label shape (148, 512, 512)\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.728515625, 0.728515625, 3.0) -> [0.728515625, 0.728515625, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "./SABS/sabs_CT_normalized/image_7.nii.gz has been saved, shape: (449, 449, 149)\n", + "./SABS/sabs_CT_normalized/label_7.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_8.nii.gz ./SABS/tmp_normalized/label_8.nii.gz\n", + "(149, 512, 512) label shape (149, 512, 512)\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "Spacing: (0.9765620231628418, 0.9765620231628418, 2.5) -> [0.9765620231628418, 0.9765620231628418, 2.5]\n", + "Size (448, 448, 149) -> [448. 448. 149.]\n", + "./SABS/sabs_CT_normalized/image_8.nii.gz has been saved, shape: (449, 449, 150)\n", + "./SABS/sabs_CT_normalized/label_8.nii.gz has been saved\n", + "./SABS/tmp_normalized/image_9.nii.gz ./SABS/tmp_normalized/label_9.nii.gz\n", + "(148, 512, 512) label shape (148, 512, 512)\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "Spacing: (0.78125, 0.78125, 3.0) -> [0.78125, 0.78125, 3.0]\n", + "Size (448, 448, 148) -> [448. 448. 148.]\n", + "./SABS/sabs_CT_normalized/image_9.nii.gz has been saved, shape: (449, 449, 149)\n", + "./SABS/sabs_CT_normalized/label_9.nii.gz has been saved\n" + ] + } + ], + "source": [ + "import copy\n", + "OUT_FOLDER = \"./SABS/sabs_CT_normalized\"\n", + "BD_BIAS = 32 # cut irrelavent empty boundary to make roi stands out\n", + "\n", + "# SPA_FAC = (512 - 2 * BD_BIAS) / 512 # spacing factor\n", + "for res in (256, 672):\n", + " if res == 672:\n", + " OUT_FOLDER += \"_672\"\n", + " scan_dir = OUT_FOLDER\n", + " os.makedirs(OUT_FOLDER, exist_ok = True)\n", + "\n", + " resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC=None, required_res=res)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Synapse Classmap Generation" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pid 0 finished!\n", + "pid 1 finished!\n", + "pid 2 finished!\n", + "pid 3 finished!\n", + "pid 4 finished!\n", + "pid 5 finished!\n", + "pid 6 finished!\n", + "pid 7 finished!\n", + "pid 8 finished!\n", + "pid 9 finished!\n", + "pid 10 finished!\n", + "pid 11 finished!\n", + "pid 12 finished!\n", + "pid 13 finished!\n", + "pid 14 finished!\n", + "pid 15 finished!\n", + "pid 16 finished!\n", + "pid 17 finished!\n", + "pid 18 finished!\n", + "pid 19 finished!\n", + "pid 20 finished!\n", + "pid 21 finished!\n", + "pid 22 finished!\n", + "pid 23 finished!\n", + "pid 24 finished!\n", + "pid 25 finished!\n", + "pid 26 finished!\n", + "pid 27 finished!\n", + "pid 28 finished!\n", + "pid 29 finished!\n", + "pid 30 finished!\n", + "pid 31 finished!\n", + "pid 32 finished!\n", + "pid 33 finished!\n", + "pid 34 finished!\n", + "pid 35 finished!\n", + "pid 36 finished!\n", + "pid 37 finished!\n" + ] + } + ], + "source": [ + "import json\n", + "# import niftiio as nio\n", + "import SimpleITK as sitk\n", + "\n", + "# normalization: cut top 2% of histogram, then doing volume-wise normalization\n", + "IMG_BNAMES = (\"./SABS/sabs_CT_normalized/image_*.nii.gz\", \"./SABS/sabs_CT_normalized_672/image_*.nii.gz\")\n", + "SEG_NAMES = (\"./SABS/sabs_CT_normalized/label_*.nii.gz\", \"./SABS/sabs_CT_normalized_672/label_*.nii.gz\")\n", + "for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):\n", + " imgs = glob.glob(IMG_BNAME)\n", + " segs = glob.glob(SEG_BNAME)\n", + " imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", + " segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", + " for img, seg in zip(imgs, segs):\n", + " print(img, seg)\n", + "\n", + " classmap = {}\n", + " LABEL_NAME = [\"BGD\", \"SPLEEN\", \"KID_R\", \"KID_l\", \"GALLBLADDER\", \"ESOPHAGUS\", \"LIVER\", \"STOMACH\", \"AORTA\", \"IVC\", \"PS_VEIN\", \"PANCREAS\", \"AG_R\", \"AG_L\"] \n", + "\n", + " MIN_TP=1 # minimum number of true positive pixels in a slice\n", + "\n", + " fid = os.path.dirname(IMG_BNAME) + f'/classmap_{MIN_TP}.json'\n", + " for _lb in LABEL_NAME:\n", + " classmap[_lb] = {}\n", + " for pid in range(len(segs)):\n", + " classmap[_lb][str(pid)] = []\n", + "\n", + " for pid, seg in enumerate(segs):\n", + " # lb_vol = nio.read_nii_bysitk(seg)\n", + " lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))\n", + " n_slice = lb_vol.shape[0]\n", + " for slc in range(n_slice):\n", + " for cls in range(len(LABEL_NAME)):\n", + " if cls in lb_vol[slc, ...]:\n", + " if np.sum( lb_vol[slc, ...] == cls) >= MIN_TP:\n", + " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", + " print(f'pid {str(pid)} finished!')\n", + " \n", + " with open(fid, 'w') as fopen:\n", + " json.dump(classmap, fopen)\n", + " fopen.close() \n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MRI Image Normalization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## PLEASE RUN dcm_img_to_nii.sh to convert dicom to nii.gz\n", + "! ./dcm_img_to_nii.sh" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [316. 316. 35.99999911]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_1.nii.gz has been saved\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 36) -> [348. 348. 35.99999911]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_10.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_13.nii.gz has been saved\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_15.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_19.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_2.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 26) -> [348. 348. 30.38961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_20.nii.gz has been saved\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 31) -> [340. 340. 32.20779221]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_21.nii.gz has been saved\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 33) -> [356. 356. 37.28571347]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_22.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 30) -> [348. 348. 35.06493506]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_3.nii.gz has been saved\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [332. 332. 33.8961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_31.nii.gz has been saved\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 39) -> [372. 372. 45.58441558]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_32.nii.gz has been saved\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 29) -> [356. 356. 33.8961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_33.nii.gz has been saved\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [356. 356. 34.28571503]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_34.nii.gz has been saved\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 38) -> [332. 332. 44.41558442]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_36.nii.gz has been saved\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 32) -> [300. 300. 37.4025974]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_37.nii.gz has been saved\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (320, 320, 34) -> [348. 348. 39.74025974]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_38.nii.gz has been saved\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 26) -> [324. 324. 30.38961039]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_39.nii.gz has been saved\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [1.25, 1.25, 7.7]\n", + "Size (256, 256, 30) -> [340. 340. 35.06493506]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_5.nii.gz has been saved\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [1.25, 1.25, 7.7]\n", + "Size (288, 288, 32) -> [324. 324. 33.24675325]\n", + "./CHAOST2/chaos_MR_T2_normalized/image_8.nii.gz has been saved\n" + ] + } + ], + "source": [ + "import copy\n", + "\n", + "IMG_FOLDER = \"./CHAOST2/niis/T2SPIR\" #, path of nii-like images from step 1\n", + "OUT_FOLDER=\"./CHAOST2/chaos_MR_T2_normalized/\" # output directory\n", + "\n", + "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n", + "imgs = [ fid for fid in sorted(imgs) ]\n", + "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n", + "\n", + "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]\n", + "for img, seg in zip(imgs, segs):\n", + " print(img, seg)\n", + "\n", + "os.makedirs(OUT_FOLDER, exist_ok = True)\n", + " \n", + "HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect\n", + "\n", + "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n", + "\n", + "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", + "\n", + " resample_flg = True\n", + "\n", + " img_obj = sitk.ReadImage( img_fid )\n", + " seg_obj = sitk.ReadImage( seg_fid )\n", + "\n", + " array = sitk.GetArrayFromImage(img_obj)\n", + "\n", + " # cut histogram\n", + " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n", + " array[array > hir] = hir\n", + "\n", + " his_img_o = sitk.GetImageFromArray(array)\n", + " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n", + "\n", + " # resampling\n", + " img_spa_ori = img_obj.GetSpacing()\n", + " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n", + " interpolator = sitk.sitkLinear, logging = True)\n", + " ## label\n", + " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", + "\n", + " # resampling\n", + " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n", + " ref_img = None, logging = True)\n", + "\n", + " # crop out rois\n", + " res_img_a = s2n(res_img_o)\n", + "\n", + " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n", + " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n", + " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n", + "\n", + " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n", + "\n", + " res_lb_a = s2n(res_lb_o)\n", + "\n", + " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n", + " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n", + " padval = 0, only_2d = True).transpose(2,0,1)\n", + "\n", + " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n", + "\n", + "\n", + " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n", + " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n", + "\n", + " # then save pre-processed images\n", + " sitk.WriteImage(out_img_obj, out_img_fid, True) \n", + " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n", + " print(\"{} has been saved\".format(out_img_fid))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MRI Resampling and ROI" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "./CHAOST2/niis/T2SPIR/image_1.nii.gz ./CHAOST2/niis/T2SPIR/label_1.nii.gz\n", + "(36, 256, 256) label shape (36, 256, 256)\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "./SABS/sabs_CT_normalized/image_1.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_1.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_10.nii.gz ./CHAOST2/niis/T2SPIR/label_10.nii.gz\n", + "(36, 256, 256) label shape (36, 256, 256)\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [0.6422642299107143, 0.6422642299107143, 7.699999809265137]\n", + "Size (254, 254, 36) -> [672. 672. 36.]\n", + "./SABS/sabs_CT_normalized/image_10.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_10.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_13.nii.gz ./CHAOST2/niis/T2SPIR/label_13.nii.gz\n", + "(30, 320, 320) label shape (30, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "./SABS/sabs_CT_normalized/image_13.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_13.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_15.nii.gz ./CHAOST2/niis/T2SPIR/label_15.nii.gz\n", + "(26, 256, 256) label shape (26, 256, 256)\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "./SABS/sabs_CT_normalized/image_15.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_15.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_19.nii.gz ./CHAOST2/niis/T2SPIR/label_19.nii.gz\n", + "(30, 320, 320) label shape (30, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "./SABS/sabs_CT_normalized/image_19.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_19.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_2.nii.gz ./CHAOST2/niis/T2SPIR/label_2.nii.gz\n", + "(26, 320, 320) label shape (26, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "./SABS/sabs_CT_normalized/image_2.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_2.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_20.nii.gz ./CHAOST2/niis/T2SPIR/label_20.nii.gz\n", + "(26, 320, 320) label shape (26, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 26) -> [672. 672. 26.]\n", + "./SABS/sabs_CT_normalized/image_20.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_20.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_21.nii.gz ./CHAOST2/niis/T2SPIR/label_21.nii.gz\n", + "(31, 256, 256) label shape (31, 256, 256)\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "Spacing: (1.66015625, 1.66015625, 8.0) -> [0.627499534970238, 0.627499534970238, 8.0]\n", + "Size (254, 254, 31) -> [672. 672. 31.]\n", + "./SABS/sabs_CT_normalized/image_21.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_21.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_22.nii.gz ./CHAOST2/niis/T2SPIR/label_22.nii.gz\n", + "(33, 256, 256) label shape (33, 256, 256)\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "Spacing: (1.73828125, 1.73828125, 8.699999809265137) -> [0.6570289248511905, 0.6570289248511905, 8.699999809265137]\n", + "Size (254, 254, 33) -> [672. 672. 33.]\n", + "./SABS/sabs_CT_normalized/image_22.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_22.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_3.nii.gz ./CHAOST2/niis/T2SPIR/label_3.nii.gz\n", + "(30, 320, 320) label shape (30, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 30) -> [672. 672. 30.]\n", + "./SABS/sabs_CT_normalized/image_3.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_3.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_31.nii.gz ./CHAOST2/niis/T2SPIR/label_31.nii.gz\n", + "(29, 256, 256) label shape (29, 256, 256)\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "./SABS/sabs_CT_normalized/image_31.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_31.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_32.nii.gz ./CHAOST2/niis/T2SPIR/label_32.nii.gz\n", + "(39, 256, 256) label shape (39, 256, 256)\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "Spacing: (1.81640625, 1.81640625, 9.0) -> [0.6865583147321428, 0.6865583147321428, 9.0]\n", + "Size (254, 254, 39) -> [672. 672. 39.]\n", + "./SABS/sabs_CT_normalized/image_32.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_32.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_33.nii.gz ./CHAOST2/niis/T2SPIR/label_33.nii.gz\n", + "(29, 256, 256) label shape (29, 256, 256)\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "Spacing: (1.73828125, 1.73828125, 9.0) -> [0.6570289248511905, 0.6570289248511905, 9.0]\n", + "Size (254, 254, 29) -> [672. 672. 29.]\n", + "./SABS/sabs_CT_normalized/image_33.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_33.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_34.nii.gz ./CHAOST2/niis/T2SPIR/label_34.nii.gz\n", + "(30, 256, 256) label shape (30, 256, 256)\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.73828125, 1.73828125, 8.800000190734863) -> [0.6570289248511905, 0.6570289248511905, 8.800000190734863]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "./SABS/sabs_CT_normalized/image_34.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_34.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_36.nii.gz ./CHAOST2/niis/T2SPIR/label_36.nii.gz\n", + "(38, 256, 256) label shape (38, 256, 256)\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "Spacing: (1.62109375, 1.62109375, 9.0) -> [0.6127348400297619, 0.6127348400297619, 9.0]\n", + "Size (254, 254, 38) -> [672. 672. 38.]\n", + "./SABS/sabs_CT_normalized/image_36.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_36.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_37.nii.gz ./CHAOST2/niis/T2SPIR/label_37.nii.gz\n", + "(32, 256, 256) label shape (32, 256, 256)\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "Spacing: (1.46484375, 1.46484375, 9.0) -> [0.5536760602678571, 0.5536760602678571, 9.0]\n", + "Size (254, 254, 32) -> [672. 672. 32.]\n", + "./SABS/sabs_CT_normalized/image_37.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_37.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_38.nii.gz ./CHAOST2/niis/T2SPIR/label_38.nii.gz\n", + "(34, 320, 320) label shape (34, 320, 320)\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "Spacing: (1.359375, 1.359375, 9.0) -> [0.6432756696428571, 0.6432756696428571, 9.0]\n", + "Size (318, 318, 34) -> [672. 672. 34.]\n", + "./SABS/sabs_CT_normalized/image_38.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_38.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_39.nii.gz ./CHAOST2/niis/T2SPIR/label_39.nii.gz\n", + "(26, 256, 256) label shape (26, 256, 256)\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "Spacing: (1.58203125, 1.58203125, 9.0) -> [0.5979701450892857, 0.5979701450892857, 9.0]\n", + "Size (254, 254, 26) -> [672. 672. 26.]\n", + "./SABS/sabs_CT_normalized/image_39.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_39.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_5.nii.gz ./CHAOST2/niis/T2SPIR/label_5.nii.gz\n", + "(30, 256, 256) label shape (30, 256, 256)\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "Spacing: (1.66015625, 1.66015625, 9.0) -> [0.627499534970238, 0.627499534970238, 9.0]\n", + "Size (254, 254, 30) -> [672. 672. 30.]\n", + "./SABS/sabs_CT_normalized/image_5.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_5.nii.gz has been saved\n", + "./CHAOST2/niis/T2SPIR/image_8.nii.gz ./CHAOST2/niis/T2SPIR/label_8.nii.gz\n", + "(32, 288, 288) label shape (32, 288, 288)\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "Label values: [0 1 2 3 4]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "Spacing: (1.40625, 1.40625, 8.0) -> [0.5984933035714286, 0.5984933035714286, 8.0]\n", + "Size (286, 286, 32) -> [672. 672. 32.]\n", + "./SABS/sabs_CT_normalized/image_8.nii.gz has been saved\n", + "./SABS/sabs_CT_normalized/label_8.nii.gz has been saved\n" + ] + } + ], + "source": [ + "# SPA_FAC = (256 - 2 * BD_BIAS) / 512 # spacing factor\n", + "BD_BIAS = 1\n", + "scan_dir = OUT_FOLDER\n", + "for res in (256, 672):\n", + " if res == 672:\n", + " scan_dir += \"_672\"\n", + " resample_imgs(imgs, segs, pids, scan_dir,\n", + " BD_BIAS, SPA_FAC=None, required_res=res)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MRI Classmap Generation" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pid 1 finished!\n", + "pid 2 finished!\n", + "pid 3 finished!\n", + "pid 5 finished!\n", + "pid 8 finished!\n", + "pid 10 finished!\n", + "pid 13 finished!\n", + "pid 15 finished!\n", + "pid 19 finished!\n", + "pid 20 finished!\n", + "pid 21 finished!\n", + "pid 22 finished!\n", + "pid 31 finished!\n", + "pid 32 finished!\n", + "pid 33 finished!\n", + "pid 34 finished!\n", + "pid 36 finished!\n", + "pid 37 finished!\n", + "pid 38 finished!\n", + "pid 39 finished!\n" + ] + } + ], + "source": [ + "IMG_BNAMES = (\"./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz\", \"./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz\")\n", + "SEG_NAMES = (\"./CHAOST2/chaos_MR_T2_normalized/label_*.nii.gz\", \"./CHAOST2/chaos_MR_T2_normalized_672/label_*.nii.gz\")\n", + "\n", + "for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):\n", + " imgs = glob.glob(IMG_BNAME)\n", + " segs = glob.glob(SEG_BNAME)\n", + " imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", + " segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", + "\n", + "\n", + " classmap = {}\n", + " LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", + "\n", + " MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", + "\n", + " fid = os.path.join(OUT_FOLDER,f'.classmap_{MIN_TP}.json') # name of the output file. \n", + " for _lb in LABEL_NAME:\n", + " classmap[_lb] = {}\n", + " for _sid in segs:\n", + " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", + " classmap[_lb][pid] = []\n", + "\n", + " for seg in segs:\n", + " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", + " lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))\n", + " n_slice = lb_vol.shape[0]\n", + " for slc in range(n_slice):\n", + " for cls in range(len(LABEL_NAME)):\n", + " if cls in lb_vol[slc, ...]:\n", + " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", + " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", + " print(f'pid {str(pid)} finished!')\n", + " \n", + " with open(fid, 'w') as fopen:\n", + " json.dump(classmap, fopen)\n", + " fopen.close() \n", + "\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Psuedo label generation for Encoder Finetuning" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import copy\n", + "import skimage\n", + "\n", + "from skimage.segmentation import slic\n", + "from skimage.segmentation import mark_boundaries\n", + "from skimage.util import img_as_float\n", + "from skimage.measure import label \n", + "import scipy.ndimage.morphology as snm\n", + "from skimage import io\n", + "import argparse\n", + "\n", + "\n", + "to01 = lambda x: (x - x.min()) / (x.max() - x.min())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Summary\n", + "\n", + "a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background\n", + "\n", + "b. Generate superpixels as pseudolabels\n", + "\n", + "Configurations of pseudlabels\n", + "\n", + "default setting of minimum superpixel sizes\n", + "`segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)`\n", + "\n", + "you can also try other configs\n", + "`segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)`" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "MODE = 'MIDDLE' # minimum size of pesudolabels. 'MIDDLE' is the default setting\n", + "\n", + "# wrapper for process 3d image in 2d\n", + "def superpix_vol(img, method = 'fezlen', **kwargs):\n", + " \"\"\"\n", + " loop through the entire volume\n", + " assuming image with axis z, x, y\n", + " \"\"\"\n", + " if method =='fezlen':\n", + " seg_func = skimage.segmentation.felzenszwalb\n", + " else:\n", + " raise NotImplementedError\n", + " \n", + " out_vol = np.zeros(img.shape)\n", + " for ii in range(img.shape[0]):\n", + " if MODE == 'MIDDLE':\n", + " segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)\n", + " else:\n", + " raise NotImplementedError\n", + " out_vol[ii, ...] = segs\n", + " \n", + " return out_vol\n", + "\n", + "# thresholding the intensity values to get a binary mask of the patient\n", + "def fg_mask2d(img_2d, thresh): # change this by your need\n", + " mask_map = np.float32(img_2d > thresh)\n", + " \n", + " def getLargestCC(segmentation): # largest connected components\n", + " labels = label(segmentation)\n", + " assert( labels.max() != 0 ) # assume at least 1 CC\n", + " largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1\n", + " return largestCC\n", + " if mask_map.max() < 0.999:\n", + " return mask_map\n", + " else:\n", + " post_mask = getLargestCC(mask_map)\n", + " fill_mask = snm.binary_fill_holes(post_mask)\n", + " return fill_mask\n", + "\n", + "# remove superpixels within the empty regions\n", + "def superpix_masking(raw_seg2d, mask2d):\n", + " raw_seg2d = np.int32(raw_seg2d)\n", + " lbvs = np.unique(raw_seg2d)\n", + " max_lb = lbvs.max()\n", + " raw_seg2d[raw_seg2d == 0] = max_lb + 1\n", + " lbvs = list(lbvs)\n", + " lbvs.append( max_lb )\n", + " raw_seg2d = raw_seg2d * mask2d\n", + " lb_new = 1\n", + " out_seg2d = np.zeros(raw_seg2d.shape)\n", + " for lbv in lbvs:\n", + " if lbv == 0:\n", + " continue\n", + " else:\n", + " out_seg2d[raw_seg2d == lbv] = lb_new\n", + " lb_new += 1\n", + " \n", + " return out_seg2d\n", + " \n", + "def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):\n", + " raw_seg = superpix_vol(img)\n", + " fg_mask_vol = np.zeros(raw_seg.shape)\n", + " processed_seg_vol = np.zeros(raw_seg.shape)\n", + " for ii in range(raw_seg.shape[0]):\n", + " if verbose:\n", + " print(\"doing {} slice\".format(ii))\n", + " _fgm = fg_mask2d(img[ii, ...], fg_thresh )\n", + " _out_seg = superpix_masking(raw_seg[ii, ...], _fgm)\n", + " fg_mask_vol[ii] = _fgm\n", + " processed_seg_vol[ii] = _out_seg\n", + " return fg_mask_vol, processed_seg_vol\n", + " \n", + "# copy spacing and orientation info between sitk objects\n", + "def copy_info(src, dst):\n", + " dst.SetSpacing(src.GetSpacing())\n", + " dst.SetOrigin(src.GetOrigin())\n", + " dst.SetDirection(src.GetDirection())\n", + " # dst.CopyInfomation(src)\n", + " return dst\n", + "\n", + "\n", + "def strip_(img, lb):\n", + " img = np.int32(img)\n", + " if isinstance(lb, float):\n", + " lb = int(lb)\n", + " return np.float32(img == lb) * float(lb)\n", + " elif isinstance(lb, list):\n", + " out = np.zeros(img.shape)\n", + " for _lb in lb:\n", + " out += np.float32(img == int(_lb)) * float(_lb)\n", + " \n", + " return out\n", + " else:\n", + " raise Exception" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATASET_CONFIG = {'SABS':{\n", + " 'img_bname': f'./SABS/sabs_CT_normalized/image_*.nii.gz',\n", + " 'out_dir': './SABS/sabs_CT_normalized',\n", + " 'fg_thresh': 1e-4\n", + " },\n", + " 'CHAOST2':{\n", + " 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',\n", + " 'out_dir': './CHAOST2/chaos_MR_T2_normalized',\n", + " 'fg_thresh': 1e-4 + 50\n", + " },\n", + " 'SABS_672':{\n", + " 'img_bname': f'./SABS/sabs_CT_normalized_672/image_*.nii.gz',\n", + " 'out_dir': './SABS/sabs_CT_normalized_672',\n", + " 'fg_thresh': 1e-4\n", + " },\n", + " 'CHAOST2_672':{\n", + " 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz',\n", + " 'out_dir': './CHAOST2/chaos_MR_T2_normalized_672',\n", + " 'fg_thresh': 1e-4 + 50\n", + " }\n", + "}\n", + "\n", + "for DOMAIN in DATASET_CONFIG.keys():\n", + " img_bname = DATASET_CONFIG[DOMAIN]['img_bname']\n", + " imgs = glob.glob(img_bname)\n", + " out_dir = DATASET_CONFIG[DOMAIN]['out_dir']\n", + "\n", + " imgs = sorted(imgs, key = lambda x: int(x.split('_')[-1].split('.nii.gz')[0]) )\n", + " print(imgs)\n", + "\n", + " # Generate pseudolabels for every image and save them\n", + " for img_fid in imgs:\n", + " # img_fid = imgs[0]\n", + "\n", + " idx = os.path.basename(img_fid).split(\"_\")[-1].split(\".nii.gz\")[0]\n", + " im_obj = sitk.ReadImage(img_fid)\n", + "\n", + " out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'] )\n", + " out_fg_o = sitk.GetImageFromArray(out_fg ) \n", + " out_seg_o = sitk.GetImageFromArray(out_seg )\n", + "\n", + " out_fg_o = copy_info(im_obj, out_fg_o)\n", + " out_seg_o = copy_info(im_obj, out_seg_o)\n", + " seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')\n", + " msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')\n", + " sitk.WriteImage(out_fg_o, msk_fid)\n", + " sitk.WriteImage(out_seg_o, seg_fid)\n", + " print(f'image with id {idx} has finished')\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dataloaders/GenericSuperDatasetv2.py b/dataloaders/GenericSuperDatasetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf04917ecb778cf49a88af71d459e35891646ae --- /dev/null +++ b/dataloaders/GenericSuperDatasetv2.py @@ -0,0 +1,445 @@ +""" +Dataset for training with pseudolabels +TODO: +1. Merge with manual annotated dataset +2. superpixel_scale -> superpix_config, feed like a dict +""" +import glob +import numpy as np +import dataloaders.augutils as myaug +import torch +import random +import os +import copy +import platform +import json +import re +import cv2 +from dataloaders.common import BaseDataset, Subset +from dataloaders.dataset_utils import* +from pdb import set_trace +from util.utils import CircularList +from util.consts import IMG_SIZE + +class SuperpixelDataset(BaseDataset): + def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], train_list = [], superpix_scale = 'SMALL', norm_mean=None, norm_std=None, supervised_train=False, use_3_slices=False, **kwargs): + """ + Pseudolabel dataset + Args: + which_dataset: name of the dataset to use + base_dir: directory of dataset + idx_split: index of data split as we will do cross validation + mode: 'train', 'val'. + nsup: number of scans used as support. currently idle for superpixel dataset + transforms: data transform (augmentation) function + scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time + num_rep: Number of augmentation applied for a same pseudolabel + tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images + fix_length: fix the length of dataset + exclude_list: Labels to be excluded + superpix_scale: config of superpixels + """ + super(SuperpixelDataset, self).__init__(base_dir) + + self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] + self.sep = DATASET_INFO[which_dataset]['_SEP'] + self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME'] + self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] + + self.image_size = image_size + self.transforms = transforms + self.is_train = True if mode == 'train' else False + self.supervised_train = supervised_train + if self.supervised_train and len(train_list) == 0: + raise Exception('Please provide training labels') + # assert mode == 'train' + self.fix_length = fix_length + if self.supervised_train: + # self.nclass = len(self.real_label_name) + self.nclass = len(self.pseu_label_name) + else: + self.nclass = len(self.pseu_label_name) + self.num_rep = num_rep + self.tile_z_dim = tile_z_dim + self.use_3_slices = use_3_slices + if tile_z_dim > 1 and self.use_3_slices: + raise Exception("tile_z_dim and use_3_slices shouldn't be used together") + + # find scans in the data folder + self.nsup = nsup + self.base_dir = base_dir + self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ] + self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) + + # experiment configs + self.exclude_lbs = exclude_list + self.train_list = train_list + self.superpix_scale = superpix_scale + if len(exclude_list) > 0: + print(f'###### Dataset: the following classes has been excluded {exclude_list}######') + self.idx_split = idx_split + self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold + self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) + self.scan_per_load = scan_per_load + + self.info_by_scan = None + self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold + self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()], ct_mean=norm_mean, ct_std=norm_std) + + if self.is_train: + if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch + self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) + else: # load the entire set without a buffer + self.pid_curr_load = self.scan_ids + elif mode == 'val': + self.pid_curr_load = self.scan_ids + else: + raise Exception + + self.use_clahe = False + if kwargs['use_clahe']: + self.use_clahe = True + clip_limit = 4.0 if self.img_modality == 'MR' else 2.0 + self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7,7)) + + self.actual_dataset = self.read_dataset() + self.size = len(self.actual_dataset) + self.overall_slice_by_cls = self.read_classfiles() + + print("###### Initial scans loaded: ######") + print(self.pid_curr_load) + + def get_scanids(self, mode, idx_split): + """ + Load scans by train-test split + leaving one additional scan as the support scan. if the last fold, taking scan 0 as the additional one + Args: + idx_split: index for spliting cross-validation folds + """ + val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) + if mode == 'train': + return [ ii for ii in self.img_pids if ii not in val_ids ] + elif mode == 'val': + return val_ids + + def reload_buffer(self): + """ + Reload a only portion of the entire dataset, if the dataset is too large + 1. delete original buffer + 2. update self.ids_this_batch + 3. update other internel variables like __len__ + """ + if self.scan_per_load <= 0: + print("We are not using the reload buffer, doing notiong") + return -1 + + del self.actual_dataset + del self.info_by_scan + + self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) + self.actual_dataset = self.read_dataset() + self.size = len(self.actual_dataset) + self.update_subclass_lookup() + print(f'Loader buffer reloaded with a new size of {self.size} slices') + + def organize_sample_fids(self): + out_list = {} + for curr_id in self.scan_ids: + curr_dict = {} + + _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') + _lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz') + _gt_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') + + curr_dict["img_fid"] = _img_fid + curr_dict["lbs_fid"] = _lb_fid + curr_dict["gt_lbs_fid"] = _gt_lb_fid + out_list[str(curr_id)] = curr_dict + return out_list + + def read_dataset(self): + """ + Read images into memory and store them in 2D + Build tables for the position of an individual 2D slice in the entire dataset + """ + out_list = [] + self.scan_z_idx = {} + self.info_by_scan = {} # meta data of each scan + glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset + + for scan_id, itm in self.img_lb_fids.items(): + if scan_id not in self.pid_curr_load: + continue + + img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out + # read connected graph of labels + if self.use_clahe: + # img = nself.clahe.apply(img.astype(np.uint8)) + if self.img_modality == 'MR': + img = np.stack([((slice - slice.min()) / (slice.max() - slice.min())) * 255 for slice in img], axis=0) + img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0) + + img = img.transpose(1,2,0) + self.info_by_scan[scan_id] = _info + + img = np.float32(img) + img = self.norm_func(img) + self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] + + if self.supervised_train: + lb = read_nii_bysitk(itm["gt_lbs_fid"]) + else: + lb = read_nii_bysitk(itm["lbs_fid"]) + lb = lb.transpose(1,2,0) + lb = np.int32(lb) + + # resize img and lb to self.image_size + img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR) + lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST) + + # format of slices: [axial_H x axial_W x Z] + if self.supervised_train: + # remove all images that dont have the training labels + del_indices = [i for i in range(img.shape[-1]) if not np.any(np.isin(lb[..., i], self.train_list))] + # create an new img and lb without indices in del_indices + new_img = np.zeros((img.shape[0], img.shape[1], img.shape[2] - len(del_indices))) + new_lb = np.zeros((lb.shape[0], lb.shape[1], lb.shape[2] - len(del_indices))) + new_img = img[..., ~np.isin(np.arange(img.shape[-1]), del_indices)] + new_lb = lb[..., ~np.isin(np.arange(lb.shape[-1]), del_indices)] + + img = new_img + lb = new_lb + a = [i for i in range(img.shape[-1]) if lb[...,i].max() == 0] + + nframes = img.shape[-1] + assert img.shape[-1] == lb.shape[-1] + base_idx = img.shape[-1] // 2 # index of the middle slice + + # re-organize 3D images into 2D slices and record essential information for each slice + out_list.append( {"img": img[..., 0: 1], + "lb":lb[..., 0: 0 + 1], + "sup_max_cls": lb[..., 0: 0 + 1].max(), + "is_start": True, + "is_end": False, + "nframe": nframes, + "scan_id": scan_id, + "z_id":0, + }) + + self.scan_z_idx[scan_id][0] = glb_idx + glb_idx += 1 + + for ii in range(1, img.shape[-1] - 1): + out_list.append( {"img": img[..., ii: ii + 1], + "lb":lb[..., ii: ii + 1], + "is_start": False, + "is_end": False, + "sup_max_cls": lb[..., ii: ii + 1].max(), + "nframe": nframes, + "scan_id": scan_id, + "z_id": ii, + }) + self.scan_z_idx[scan_id][ii] = glb_idx + glb_idx += 1 + + ii += 1 # last slice of a 3D volume + out_list.append( {"img": img[..., ii: ii + 1], + "lb":lb[..., ii: ii+ 1], + "is_start": False, + "is_end": True, + "sup_max_cls": lb[..., ii: ii + 1].max(), + "nframe": nframes, + "scan_id": scan_id, + "z_id": ii, + }) + + self.scan_z_idx[scan_id][ii] = glb_idx + glb_idx += 1 + + return out_list + + def read_classfiles(self): + """ + Load the scan-slice-class indexing file + """ + with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen: + cls_map = json.load( fopen) + fopen.close() + + with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen: + self.tp1_cls_map = json.load( fopen) + fopen.close() + + return cls_map + + def get_superpixels_similarity(self, sp1, sp2): + pass + + def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None): + if bi_val is None: + # bi_val = np.random.randint(1, sup_max_cls) + bi_val = random.choice(list(np.unique(super_map))) + if conn_graph is not None and img is not None: + # get number of neighbors of bi_val + neighbors = conn_graph[bi_val] + # pick a random number of neighbors and merge them + n_neighbors = np.random.randint(0, len(neighbors)) + try: + neighbors = random.sample(neighbors, n_neighbors) + except TypeError: + neighbors = [] + # merge neighbors + super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map) + return np.float32(super_map == bi_val) + + def supcls_pick(self, super_map): + return random.choice(list(np.unique(super_map))) + + def get_3_slice_adjacent_image(self, image_t, index): + curr_dict = self.actual_dataset[index] + prev_image = np.zeros_like(image_t) + + if index > 1 and not curr_dict["is_start"]: + prev_dict = self.actual_dataset[index - 1] + prev_image = prev_dict["img"] + + next_image = np.zeros_like(image_t) + if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: + next_dict = self.actual_dataset[index + 1] + next_image = next_dict["img"] + + image_t = np.concatenate([prev_image, image_t, next_image], axis=-1) + + return image_t + + def __getitem__(self, index): + index = index % len(self.actual_dataset) + curr_dict = self.actual_dataset[index] + sup_max_cls = curr_dict['sup_max_cls'] + if sup_max_cls < 1: + return self.__getitem__(index + 1) + + image_t = curr_dict["img"] + label_raw = curr_dict["lb"] + + if self.use_3_slices: + image_t = self.get_3_slice_adjacent_image(image_t, index) + + for _ex_cls in self.exclude_lbs: + if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen + return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) + + if self.supervised_train: + superpix_label = -1 + label_t = np.float32(label_raw) + + lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list))) + label_t[label_t != lb_id] = 0 + label_t[label_t == lb_id] = 1 + + else: + superpix_label = self.supcls_pick(label_raw) + label_t = np.float32(label_raw == superpix_label) + + pair_buffer = [] + + comp = np.concatenate( [image_t, label_t], axis = -1 ) + + for ii in range(self.num_rep): + if self.transforms is not None: + img, lb = self.transforms(comp, c_img = image_t.shape[-1], c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False) + else: + img, lb = comp[:, :, 0:1], comp[:, :, 1:2] + # if ii % 2 == 0: + # label_raw = lb + # lb = lb == superpix_label + + img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ).float() + lb = torch.from_numpy( lb.squeeze(-1)).float() + + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + + is_start = curr_dict["is_start"] + is_end = curr_dict["is_end"] + nframe = np.int32(curr_dict["nframe"]) + scan_id = curr_dict["scan_id"] + z_id = curr_dict["z_id"] + + sample = {"image": img, + "label":lb, + "is_start": is_start, + "is_end": is_end, + "nframe": nframe, + "scan_id": scan_id, + "z_id": z_id + } + + # Add auxiliary attributes + if self.aux_attrib is not None: + for key_prefix in self.aux_attrib: + # Process the data sample, create new attributes and save them in a dictionary + aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) + for key_suffix in aux_attrib_val: + # one function may create multiple attributes, so we need suffix to distinguish them + sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] + pair_buffer.append(sample) + + support_images = [] + support_mask = [] + support_class = [] + + query_images = [] + query_labels = [] + query_class = [] + + for idx, itm in enumerate(pair_buffer): + if idx % 2 == 0: + support_images.append(itm["image"]) + support_class.append(1) # pseudolabel class + support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] )) + else: + query_images.append(itm["image"]) + query_class.append(1) + query_labels.append( itm["label"]) + + return {'class_ids': [support_class], + 'support_images': [support_images], # + 'superpix_label': superpix_label, + 'superpix_label_raw': label_raw[:,:,0], + 'support_mask': [support_mask], + 'query_images': query_images, # + 'query_labels': query_labels, + 'scan_id': scan_id, + 'z_id': z_id, + 'nframe': nframe, + } + + + def __len__(self): + """ + copy-paste from basic naive dataset configuration + """ + if self.fix_length != None: + assert self.fix_length >= len(self.actual_dataset) + return self.fix_length + else: + return len(self.actual_dataset) + + def getMaskMedImg(self, label, class_id, class_ids): + """ + Generate FG/BG mask from the segmentation mask + + Args: + label: semantic mask + class_id: semantic class of interest + class_ids: all class id in this episode + """ + fg_mask = torch.where(label == class_id, + torch.ones_like(label), torch.zeros_like(label)) + bg_mask = torch.where(label != class_id, + torch.ones_like(label), torch.zeros_like(label)) + for class_id in class_ids: + bg_mask[label == class_id] = 0 + + return {'fg_mask': fg_mask, + 'bg_mask': bg_mask} diff --git a/dataloaders/ManualAnnoDatasetv2.py b/dataloaders/ManualAnnoDatasetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..d8312d20514c5177d03e41ec05b3c8b8c89cf6e0 --- /dev/null +++ b/dataloaders/ManualAnnoDatasetv2.py @@ -0,0 +1,756 @@ +""" +Manually labeled dataset +TODO: +1. Merge with superpixel dataset +""" +import glob +import numpy as np +import dataloaders.augutils as myaug +import torch +import random +import os +import copy +import platform +import json +import re +import cv2 +from dataloaders.common import BaseDataset, Subset, ValidationDataset +# from common import BaseDataset, Subset +from dataloaders.dataset_utils import* +from pdb import set_trace +from util.utils import CircularList +from util.consts import IMG_SIZE + +MODE_DEFAULT = "default" +MODE_FULL_SCAN = "full_scan" + +class ManualAnnoDataset(BaseDataset): + def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None, **kwargs): + """ + Manually labeled dataset + Args: + which_dataset: name of the dataset to use + base_dir: directory of dataset + idx_split: index of data split as we will do cross validation + mode: 'train', 'val'. + transforms: data transform (augmentation) function + min_fg: minimum number of positive pixels in a 2D slice, mainly for stablize training when trained on manually labeled dataset + scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time + tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images + nsup: number of support scans + fix_length: fix the length of dataset + exclude_list: Labels to be excluded + extern_normalize_function: normalization function used for data pre-processing + """ + super(ManualAnnoDataset, self).__init__(base_dir) + self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] + self.sep = DATASET_INFO[which_dataset]['_SEP'] + self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] + self.image_size = image_size + self.transforms = transforms + self.is_train = True if mode == 'train' else False + self.phase = mode + self.fix_length = fix_length + self.all_label_names = self.label_name + self.nclass = len(self.label_name) + self.tile_z_dim = tile_z_dim + self.base_dir = base_dir + self.nsup = nsup + self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ] + self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of spliting folds + if 'use_clahe' not in kwargs: + self.use_clahe = False + else: + self.use_clahe = kwargs['use_clahe'] + if self.use_clahe: + self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7)) + + self.use_3_slices = kwargs["use_3_slices"] if 'use_3_slices' in kwargs else False + if self.use_3_slices: + self.tile_z_dim=1 + + self.get_item_mode = MODE_DEFAULT + if 'get_item_mode' in kwargs: + self.get_item_mode = kwargs['get_item_mode'] + + self.exclude_lbs = exclude_list + if len(exclude_list) > 0: + print(f'###### Dataset: the following classes has been excluded {exclude_list}######') + + self.idx_split = idx_split + self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold + self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) + + self.scan_per_load = scan_per_load + + self.info_by_scan = None + self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold + + if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset. + self.norm_func = extern_normalize_func + print(f'###### Dataset: using external normalization statistics ######') + else: + self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()]) + print(f'###### Dataset: using normalization statistics calculated from loaded data ######') + + if self.is_train: + if scan_per_load > 0: # buffer needed + self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) + else: # load the entire set without a buffer + self.pid_curr_load = self.scan_ids + elif mode == 'val': + self.pid_curr_load = self.scan_ids + self.potential_support_sid = [] + else: + raise Exception + self.actual_dataset = self.read_dataset() + self.size = len(self.actual_dataset) + self.overall_slice_by_cls = self.read_classfiles() + self.update_subclass_lookup() + + def get_scanids(self, mode, idx_split): + val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) + self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index + if mode == 'train': + return [ ii for ii in self.img_pids if ii not in val_ids ] + elif mode == 'val': + return val_ids + + def reload_buffer(self): + """ + Reload a portion of the entire dataset, if the dataset is too large + 1. delete original buffer + 2. update self.ids_this_batch + 3. update other internel variables like __len__ + """ + if self.scan_per_load <= 0: + print("We are not using the reload buffer, doing notiong") + return -1 + + del self.actual_dataset + del self.info_by_scan + self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) + self.actual_dataset = self.read_dataset() + self.size = len(self.actual_dataset) + self.update_subclass_lookup() + print(f'Loader buffer reloaded with a new size of {self.size} slices') + + def organize_sample_fids(self): + out_list = {} + for curr_id in self.scan_ids: + curr_dict = {} + + _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') + _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') + + curr_dict["img_fid"] = _img_fid + curr_dict["lbs_fid"] = _lb_fid + out_list[str(curr_id)] = curr_dict + return out_list + + def read_dataset(self): + """ + Build index pointers to individual slices + Also keep a look-up table from scan_id, slice to index + """ + out_list = [] + self.scan_z_idx = {} + self.info_by_scan = {} # meta data of each scan + glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset + + for scan_id, itm in self.img_lb_fids.items(): + if scan_id not in self.pid_curr_load: + continue + + img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out + + img = img.transpose(1,2,0) + + self.info_by_scan[scan_id] = _info + + if self.use_clahe: + img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0) + + img = np.float32(img) + img = self.norm_func(img) + + self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] + + lb = read_nii_bysitk(itm["lbs_fid"]) + lb = lb.transpose(1,2,0) + + lb = np.float32(lb) + + img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR) + lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST) + + assert img.shape[-1] == lb.shape[-1] + base_idx = img.shape[-1] // 2 # index of the middle slice + + # write the beginning frame + out_list.append( {"img": img[..., 0: 1], + "lb":lb[..., 0: 0 + 1], + "is_start": True, + "is_end": False, + "nframe": img.shape[-1], + "scan_id": scan_id, + "z_id":0}) + + self.scan_z_idx[scan_id][0] = glb_idx + glb_idx += 1 + + for ii in range(1, img.shape[-1] - 1): + out_list.append( {"img": img[..., ii: ii + 1], + "lb":lb[..., ii: ii + 1], + "is_start": False, + "is_end": False, + "nframe": -1, + "scan_id": scan_id, + "z_id": ii + }) + self.scan_z_idx[scan_id][ii] = glb_idx + glb_idx += 1 + + ii += 1 # last frame, note the is_end flag + out_list.append( {"img": img[..., ii: ii + 1], + "lb":lb[..., ii: ii+ 1], + "is_start": False, + "is_end": True, + "nframe": -1, + "scan_id": scan_id, + "z_id": ii + }) + + self.scan_z_idx[scan_id][ii] = glb_idx + glb_idx += 1 + + return out_list + + def read_classfiles(self): + with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen: + cls_map = json.load( fopen) + fopen.close() + + with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen: + self.tp1_cls_map = json.load( fopen) + fopen.close() + + return cls_map + + def __getitem__(self, index): + if self.get_item_mode == MODE_DEFAULT: + return self.__getitem_default__(index) + elif self.get_item_mode == MODE_FULL_SCAN: + return self.__get_ct_scan___(index) + else: + raise NotImplementedError("Unknown mode") + + + def __get_ct_scan___(self, index): + scan_n = index % len(self.scan_z_idx) + scan_id = list(self.scan_z_idx.keys())[scan_n] + scan_slices = self.scan_z_idx[scan_id] + + scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) + + scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) + + scan_imgs = np.float32(scan_imgs) + scan_lbs = np.float32(scan_lbs) + + scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0) + scan_lbs = torch.from_numpy(scan_lbs) + + if self.tile_z_dim: + scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) + assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' + + # # reshape to C, D, H, W + # scan_imgs = scan_imgs.permute(1, 0, 2, 3) + # scan_lbs = scan_lbs.permute(1, 0, 2, 3) + + sample = {"image": scan_imgs, + "label":scan_lbs, + "scan_id": scan_id, + } + + return sample + + + def get_3_slice_adjacent_image(self, image_t, index): + curr_dict = self.actual_dataset[index] + prev_image = np.zeros_like(image_t) + + if index > 1 and not curr_dict["is_start"]: + prev_dict = self.actual_dataset[index - 1] + prev_image = prev_dict["img"] + + next_image = np.zeros_like(image_t) + if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: + next_dict = self.actual_dataset[index + 1] + next_image = next_dict["img"] + + image_t = np.concatenate([prev_image, image_t, next_image], axis=-1) + + return image_t + + + def __getitem_default__(self, index): + index = index % len(self.actual_dataset) + curr_dict = self.actual_dataset[index] + if self.is_train: + if len(self.exclude_lbs) > 0: + for _ex_cls in self.exclude_lbs: + if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen + return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) + + comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 ) + if self.transforms is not None: + img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False) + else: + raise Exception("No transform function is provided") + + else: + img = curr_dict['img'] + lb = curr_dict['lb'] + + + img = np.float32(img) + lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure + if self.use_3_slices: + img = self.get_3_slice_adjacent_image(img, index) + + img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) + lb = torch.from_numpy( lb) + + if self.tile_z_dim: + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + assert img.ndimension() == 3, f'actual dim {img.ndimension()}' + + is_start = curr_dict["is_start"] + is_end = curr_dict["is_end"] + nframe = np.int32(curr_dict["nframe"]) + scan_id = curr_dict["scan_id"] + z_id = curr_dict["z_id"] + + sample = {"image": img, + "label":lb, + "is_start": is_start, + "is_end": is_end, + "nframe": nframe, + "scan_id": scan_id, + "z_id": z_id + } + # Add auxiliary attributes + if self.aux_attrib is not None: + for key_prefix in self.aux_attrib: + # Process the data sample, create new attributes and save them in a dictionary + aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) + for key_suffix in aux_attrib_val: + # one function may create multiple attributes, so we need suffix to distinguish them + sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] + + return sample + + def __len__(self): + """ + copy-paste from basic naive dataset configuration + """ + if self.get_item_mode == MODE_FULL_SCAN: + return len(self.scan_z_idx) + + if self.fix_length != None: + assert self.fix_length >= len(self.actual_dataset) + return self.fix_length + else: + return len(self.actual_dataset) + + def update_subclass_lookup(self): + """ + Updating the class-slice indexing list + Args: + [internal] overall_slice_by_cls: + { + class1: {pid1: [slice1, slice2, ....], + pid2: [slice1, slice2]}, + ...} + class2: + ... + } + out[internal]: + { + class1: [ idx1, idx2, ... ], + class2: [ idx1, idx2, ... ], + ... + } + + """ + # delete previous ones if any + assert self.overall_slice_by_cls is not None + + if not hasattr(self, 'idx_by_class'): + self.idx_by_class = {} + # filter the new one given the actual list + for cls in self.label_name: + if cls not in self.idx_by_class.keys(): + self.idx_by_class[cls] = [] + else: + del self.idx_by_class[cls][:] + for cls, dict_by_pid in self.overall_slice_by_cls.items(): + for pid, slice_list in dict_by_pid.items(): + if pid not in self.pid_curr_load: + continue + self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ] + print("###### index-by-class table has been reloaded ######") + + def getMaskMedImg(self, label, class_id, class_ids): + """ + Generate FG/BG mask from the segmentation mask. Used when getting the support + """ + # Dense Mask + fg_mask = torch.where(label == class_id, + torch.ones_like(label), torch.zeros_like(label)) + bg_mask = torch.where(label != class_id, + torch.ones_like(label), torch.zeros_like(label)) + for class_id in class_ids: + bg_mask[label == class_id] = 0 + + return {'fg_mask': fg_mask, + 'bg_mask': bg_mask} + + def subsets(self, sub_args_lst=None): + """ + Override base-class subset method + Create subsets by scan_ids + + output: list [[] , ] + """ + + if sub_args_lst is not None: + subsets = [] + ii = 0 + for cls_name, index_list in self.idx_by_class.items(): + subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) ) + ii += 1 + else: + subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()] + return subsets + + def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int): + """ + getting (probably multi-shot) support set for evaluation + sample from 50% (1shot) or 20 35 50 65 80 (5shot) + Args: + curr_cls: current class to segment, starts from 1 + class_idx: a list of all foreground class in nways, starts from 1 + npart: how may chunks used to split the support + scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan + being served as support, in self.pid_curr_load + """ + assert npart % 2 == 1 + assert curr_class != 0; assert 0 not in class_idx + # assert not self.is_train + + self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] + # print(f'###### Using {len(scan_idx)} shot evaluation!') + + if npart == 1: + pcts = [0.5] + else: + half_part = 1 / (npart * 2) + part_interval = (1.0 - 1.0 / npart) / (npart - 1) + pcts = [ half_part + part_interval * ii for ii in range(npart) ] + + # print(f'###### Parts percentage: {pcts} ######') + + # norm_func = get_normalize_op(modality='MR', fids=None) + out_buffer = [] # [{scanid, img, lb}] + for _part in range(npart): + concat_buffer = [] # for each fold do a concat in image and mask in batch dimension + for scan_order in scan_idx: + _scan_id = self.pid_curr_load[ scan_order ] + print(f'Using scan {_scan_id} as support!') + + # for _pc in pcts: + _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices + _zid = _zlist[int(pcts[_part] * len(_zlist))] + _glb_idx = self.scan_z_idx[_scan_id][_zid] + + # almost copy-paste __getitem__ but no augmentation + curr_dict = self.actual_dataset[_glb_idx] + img = curr_dict['img'] + lb = curr_dict['lb'] + + if self.use_3_slices: + prev_image = np.zeros_like(img) + if _glb_idx > 1 and not curr_dict["is_start"]: + prev_dict = self.actual_dataset[_glb_idx - 1] + prev_image = prev_dict["img"] + + next_image = np.zeros_like(img) + if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: + next_dict = self.actual_dataset[_glb_idx + 1] + next_image = next_dict["img"] + + img = np.concatenate([prev_image, img, next_image], axis=-1) + + img = np.float32(img) + lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure + + img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) + lb = torch.from_numpy( lb ) + + if self.tile_z_dim: + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + assert img.ndimension() == 3, f'actual dim {img.ndimension()}' + + is_start = curr_dict["is_start"] + is_end = curr_dict["is_end"] + nframe = np.int32(curr_dict["nframe"]) + scan_id = curr_dict["scan_id"] + z_id = curr_dict["z_id"] + + sample = {"image": img, + "label":lb, + "is_start": is_start, + "inst": None, + "scribble": None, + "is_end": is_end, + "nframe": nframe, + "scan_id": scan_id, + "z_id": z_id + } + + concat_buffer.append(sample) + out_buffer.append({ + "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), + "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), + + }) + + # do the concat, and add to output_buffer + + # post-processing, including keeping the foreground and suppressing background. + support_images = [] + support_mask = [] + support_class = [] + for itm in out_buffer: + support_images.append(itm["image"]) + support_class.append(curr_class) + support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) + + return {'class_ids': [support_class], + 'support_images': [support_images], # + 'support_mask': [support_mask], + } + + def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list): + self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] + # print(f'###### Using {len(scan_idx)} shot evaluation!') + scan_slices = self.scan_z_idx[self.potential_support_sid[0]] + scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) + + scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) + # binarize the labels + scan_lbs[scan_lbs != curr_class] = 0 + scan_lbs[scan_lbs == curr_class] = 1 + + scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0) + scan_lbs = torch.from_numpy(np.float32(scan_lbs)) + + if self.tile_z_dim: + scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) + assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' + + # reshape to C, D, H, W + sample = {"scan": scan_imgs, + "labels":scan_lbs, + } + + return sample + + + def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False): + """ + getting (probably multi-shot) support set for evaluation + sample from 50% (1shot) or 20 35 50 65 80 (5shot) + Args: + curr_cls: current class to segment, starts from 1 + class_idx: a list of all foreground class in nways, starts from 1 + npart: how may chunks used to split the support + scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan + being served as support, in self.pid_curr_load + """ + assert npart % 2 == 1 + # assert curr_class != 0; assert 0 not in class_idx + # assert not self.is_train + + self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] + # print(f'###### Using {len(scan_idx)} shot evaluation!') + + if npart == 1: + pcts = [0.5] + else: + half_part = 1 / (npart * 2) + part_interval = (1.0 - 1.0 / npart) / (npart - 1) + pcts = [ half_part + part_interval * ii for ii in range(npart) ] + + # print(f'###### Parts percentage: {pcts} ######') + + out_buffer = [] # [{scanid, img, lb}] + for _part in range(npart): + concat_buffer = [] # for each fold do a concat in image and mask in batch dimension + for scan_order in scan_idx: + _scan_id = self.pid_curr_load[ scan_order ] + print(f'Using scan {_scan_id} as support!') + + # for _pc in pcts: + zlist = [] + for curr_class in classes: + zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices + # merge all the lists in zlist and keep only the unique elements + # _zlist = sorted(list(set([item for sublist in zlist for item in sublist]))) + # take only the indices that appear in all of the sublist + _zlist = sorted(list(set.intersection(*map(set, zlist)))) + _zid = _zlist[int(pcts[_part] * len(_zlist))] + _glb_idx = self.scan_z_idx[_scan_id][_zid] + + # almost copy-paste __getitem__ but no augmentation + curr_dict = self.actual_dataset[_glb_idx] + img = curr_dict['img'] + lb = curr_dict['lb'] + + if use_3_slices: + prev_image = np.zeros_like(img) + if _glb_idx > 1 and not curr_dict["is_start"]: + prev_dict = self.actual_dataset[_glb_idx - 1] + assert prev_dict["scan_id"] == curr_dict["scan_id"] + assert prev_dict["z_id"] == curr_dict["z_id"] - 1 + prev_image = prev_dict["img"] + + next_image = np.zeros_like(img) + if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: + next_dict = self.actual_dataset[_glb_idx + 1] + assert next_dict["scan_id"] == curr_dict["scan_id"] + assert next_dict["z_id"] == curr_dict["z_id"] + 1 + next_image = next_dict["img"] + + img = np.concatenate([prev_image, img, next_image], axis=-1) + + img = np.float32(img) + lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure + # zero all labels that are not in the classes arg + mask = np.zeros_like(lb) + for cls in classes: + mask[lb == cls] = 1 + lb[~mask.astype(np.bool)] = 0 + + img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) + lb = torch.from_numpy( lb ) + + if self.tile_z_dim: + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + assert img.ndimension() == 3, f'actual dim {img.ndimension()}' + + is_start = curr_dict["is_start"] + is_end = curr_dict["is_end"] + nframe = np.int32(curr_dict["nframe"]) + scan_id = curr_dict["scan_id"] + z_id = curr_dict["z_id"] + + sample = {"image": img, + "label":lb, + "is_start": is_start, + "inst": None, + "scribble": None, + "is_end": is_end, + "nframe": nframe, + "scan_id": scan_id, + "z_id": z_id + } + + concat_buffer.append(sample) + out_buffer.append({ + "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), + "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), + + }) + + # do the concat, and add to output_buffer + + # post-processing, including keeping the foreground and suppressing background. + support_images = [] + support_mask = [] + support_class = [] + for itm in out_buffer: + support_images.append(itm["image"]) + support_class.append(curr_class) + # support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) + support_mask.append(itm["label"]) + + return {'class_ids': [support_class], + 'support_images': [support_images], # + 'support_mask': [support_mask], + 'scan_id': scan_id + } + +def get_nii_dataset(config, image_size, **kwargs): + print(f"Check config: {config}") + organ_mapping = { + "sabs":{ + "rk": 2, + "lk": 3, + "liver": 6, + "spleen": 1 + }, + "chaost2":{ + "liver": 1, + "rk": 2, + "lk": 3, + "spleen": 4 + }} + + transforms = None + data_name = config['dataset'] + if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672': + baseset_name = 'SABS' + max_label = 13 + modality="CT" + elif data_name == 'C0_Superpix': + raise NotImplementedError + baseset_name = 'C0' + max_label = 3 + elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672': + baseset_name = 'CHAOST2' + max_label = 4 + modality="MR" + elif 'lits' in data_name.lower(): + baseset_name = 'LITS17' + max_label = 4 + else: + raise ValueError(f'Dataset: {data_name} not found') + + # norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics + # norm_func = None + + test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]] + base_dir = config['path'][data_name]['data_dir'] + testdataset = ManualAnnoDataset(which_dataset=baseset_name, + base_dir=base_dir, + idx_split = config['eval_fold'], + mode = 'val', + scan_per_load = 1, + transforms=transforms, + min_fg=1, + nsup = config["task"]["n_shots"], + fix_length=None, + image_size=image_size, + # extern_normalize_func=norm_func + **kwargs) + + testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"]) + testdataset.set_curr_cls(test_label) + + traindataset = None # TODO make this the support set later + + return traindataset, testdataset \ No newline at end of file diff --git a/dataloaders/PolypDataset.py b/dataloaders/PolypDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb827d1618887495bf6bbbed29a202c0a085b3b --- /dev/null +++ b/dataloaders/PolypDataset.py @@ -0,0 +1,548 @@ +""" +Copied from https://github.com/talshaharabany/AutoSAM +""" + +import os +from PIL import Image +import torch.utils.data as data +import torchvision.transforms as transforms +import numpy as np +import random +import torch +from dataloaders.PolypTransforms import get_polyp_transform +import cv2 +KVASIR = "Kvasir" +CLINIC_DB = "CVC-ClinicDB" +COLON_DB = "CVC-ColonDB" +ETIS_DB = "ETIS-LaribPolypDB" +CVC300 = "CVC-300" + +DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB) +EXCLUDE_DS = (CVC300, ) + + +def create_suppport_set_for_polyps(n_support=10): + """ + create a text file contating n_support_images for each dataset + """ + root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset" + supp_images = [] + supp_masks = [] + + image_dir = os.path.join(root_dir, "images") + mask_dir = os.path.join(root_dir, "masks") + # randonly sample n_support images and masks + image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir( + image_dir) if f.endswith('.jpg') or f.endswith('.png')]) + mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir( + mask_dir) if f.endswith('.png')]) + + while len(supp_images) < n_support: + index = random.randint(0, len(image_paths) - 1) + # check that the index is not already in the support set + if image_paths[index] in supp_images: + continue + supp_images.append(image_paths[index]) + supp_masks.append(mask_paths[index]) + + with open(os.path.join(root_dir, "support.txt"), 'w') as file: + for image_path, mask_path in zip(supp_images, supp_masks): + file.write(f"{image_path} {mask_path}\n") + +def create_train_val_test_split_for_polyps(): + root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/" + # for each subdir in root_dir, create a split file + num_train_images_per_dataset = { + "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0} + + num_test_images_per_dataset = { + "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380} + + for subdir in os.listdir(root_dir): + subdir_path = os.path.join(root_dir, subdir) + if os.path.isdir(subdir_path): + split_file = os.path.join(subdir_path, "split.txt") + image_dir = os.path.join(subdir_path, "images") + create_train_val_test_split( + image_dir, split_file, train_number=num_train_images_per_dataset[subdir], test_number=num_test_images_per_dataset[subdir]) + + +def create_train_val_test_split(root, split_file, train_number=100, test_number=20): + """ + Create a train, val, test split file for a dataset + root: root directory of dataset + split_file: name of split file to create + train_ratio: ratio of train set + val_ratio: ratio of val set + test_ratio: ratio of test set + """ + # Get all files in root directory + files = os.listdir(root) + # Filter out non-image files, remove suffix + files = [f.split('.')[0] + for f in files if f.endswith('.jpg') or f.endswith('.png')] + # Shuffle files + random.shuffle(files) + + # Calculate number of files for each split + num_files = len(files) + num_train = train_number + num_test = test_number + num_val = num_files - num_train - num_test + print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}") + # Create splits + train = files[:num_train] + val = files[num_train:num_train + num_val] + test = files[num_train + num_val:] + + # Write splits to file + with open(split_file, 'w') as file: + file.write("train\n") + for f in train: + file.write(f + "\n") + file.write("val\n") + for f in val: + file.write(f + "\n") + file.write("test\n") + for f in test: + file.write(f + "\n") + + +class PolypDataset(data.Dataset): + """ + dataloader for polyp segmentation tasks + """ + + def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None): + self.trainsize = trainsize + self.augmentations = augmentations + self.datasets = datasets + if isinstance(image_size, int): + image_size = (image_size, image_size) + self.image_size = image_size + if image_root is not None and gt_root is not None: + self.images = [ + os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] + self.gts = [ + os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png')] + # also look in subdirectories + for subdir in os.listdir(image_root): + # if not dir, continue + if not os.path.isdir(os.path.join(image_root, subdir)): + continue + subdir_image_root = os.path.join(image_root, subdir) + subdir_gt_root = os.path.join(gt_root, subdir) + self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir( + subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')]) + self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir( + subdir_gt_root) if f.endswith('.png')]) + + else: + self.images, self.gts = self.get_image_gt_pairs( + root, split="train" if train else "test", datasets=self.datasets) + self.images = sorted(self.images) + self.gts = sorted(self.gts) + if not 'VPS' in root: + self.filter_files_and_get_ds_mean_and_std() + if ds_mean is not None and ds_std is not None: + self.mean, self.std = ds_mean, ds_std + self.size = len(self.images) + self.train = train + self.sam_trans = sam_trans + if self.sam_trans is not None: + # sam trans takes care of norm + self.mean, self.std = 0 , 1 + + def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS): + """ + for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file + dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders + split: train, val, or test + """ + image_paths = [] + gt_paths = [] + for folder in os.listdir(dir_root): + if folder not in datasets: + continue + split_file = os.path.join(dir_root, folder, "split.txt") + if os.path.isfile(split_file): + image_root = os.path.join(dir_root, folder, "images") + gt_root = os.path.join(dir_root, folder, "masks") + image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file( + image_root, gt_root, split_file, split=split) + image_paths.extend(image_paths_tmp) + gt_paths.extend(gt_paths_tmp) + else: + print( + f"No split.txt file found in {os.path.join(dir_root, folder)}") + + return image_paths, gt_paths + + def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"): + """ + image_root: root directory of images + gt_root: root directory of ground truth + text_file: text file containing train, val, test split with the following format: + train: + image1 + image2 + ... + val: + image1 + image2 + ... + test: + image1 + image2 + ... + + split: train, val, or test + """ + # Initialize a dictionary to hold file names for each split + splits = {"train": [], "val": [], "test": []} + current_split = None + + # Read the text file and categorize file names under each split + with open(text_file, 'r') as file: + for line in file: + line = line.strip() + if line in splits: + current_split = line + elif line and current_split: + splits[current_split].append(line) + + # Get the file names for the requested split + file_names = splits[split] + + # Create image-ground truth pairs + image_paths = [] + gt_paths = [] + for name in file_names: + image_path = os.path.join(image_root, name + '.png') + gt_path = os.path.join(gt_root, name + '.png') + image_paths.append(image_path) + gt_paths.append(gt_path) + + return image_paths, gt_paths + + def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1): + support_images = [] + support_labels = [] + # get all images and masks + support_image_paths = sorted([os.path.join(support_image_dir, f) for f in os.listdir( + support_image_dir) if f.endswith('.jpg') or f.endswith('.png')]) + support_mask_paths = sorted([os.path.join(support_mask_dir, f) for f in os.listdir( + support_mask_dir) if f.endswith('.png')]) + # sample n_support images and masks + for i in range(n_support): + index = random.randint(0, len(support_image_paths) - 1) + support_img = self.cv2_loader( + support_image_paths[index], is_mask=False) + support_mask = self.cv2_loader( + support_mask_paths[index], is_mask=True) + support_images.append(support_img) + support_labels.append(support_mask) + + if self.augmentations: + support_images = [self.augmentations( + img, mask)[0] for img, mask in zip(support_images, support_labels)] + support_labels = [self.augmentations( + img, mask)[1] for img, mask in zip(support_images, support_labels)] + + support_images = [(support_image - self.mean) / self.std if support_image.max() == 255 and support_image.min() == 0 else support_image for support_image in support_images] + + if self.sam_trans is not None: + support_images = [self.sam_trans.preprocess( + img).squeeze(0) for img in support_images] + support_labels = [self.sam_trans.preprocess( + mask) for mask in support_labels] + else: + image_size = self.image_size + support_images = [torch.nn.functional.interpolate(img.unsqueeze( + 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images] + support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze( + 0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels] + + return torch.stack(support_images), torch.stack(support_labels) + + def get_support_from_text_file(self, text_file, n_support=1): + """ + each row in the file has 2 paths divided by space, the first is the image path and the second is the mask path + """ + support_images = [] + support_labels = [] + with open(text_file, 'r') as file: + for line in file: + image_path, mask_path = line.strip().split() + support_images.append(image_path) + support_labels.append(mask_path) + + # indices = random.choices(range(len(support_images)), k=n_support) + if n_support > len(support_images): + raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})") + + n_support_images = support_images[:n_support] + n_support_labels = support_labels[:n_support] + + return n_support_images, n_support_labels + + def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None): + """ + Get support set from specified directories, text file or from the dataset itself + """ + if support_image_dir is not None and support_mask_dir: + return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support) + elif text_file is not None: + support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support) + else: + # randomly sample n_support images and masks from the dataset + indices = random.choices(range(self.size), k=n_support) + # indices = list(range(n_support)) + print(f"support indices:{indices}") + support_image_paths = [self.images[index] for index in indices] + support_gt_paths = [self.gts[index] for index in indices] + + support_images = [] + support_gts = [] + + for image_path, gt_path in zip(support_image_paths, support_gt_paths): + support_img = self.cv2_loader(image_path, is_mask=False) + support_mask = self.cv2_loader(gt_path, is_mask=True) + out = self.process_image_gt(support_img, support_mask) + support_images.append(out['image'].unsqueeze(0)) + support_gts.append(out['label'].unsqueeze(0)) + if len(support_images) >= n_support: + break + return support_images, support_gts, out['case'] + # return torch.stack(support_images), torch.stack(support_gts), out['case'] + + def process_image_gt(self, image, gt, dataset=""): + """ + image and gt are expected to be output from self.cv2_loader + """ + original_size = tuple(image.shape[-2:]) + if self.augmentations: + image, mask = self.augmentations(image, gt) + + if self.sam_trans: + image, mask = self.sam_trans.apply_image_torch( + image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask) + elif image.max() <= 255 and image.min() >= 0: + image = (image - self.mean) / self.std + mask[mask > 0.5] = 1 + mask[mask <= 0.5] = 0 + # image_size = tuple(img.shape[-2:]) + + image_size = self.image_size + if self.sam_trans is None: + image = torch.nn.functional.interpolate(image.unsqueeze( + 0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) + mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze( + 0), size=image_size, mode='nearest').squeeze(0).squeeze(0) + # img = (img - img.min()) / (img.max() - img.min()) # TODO uncomment this if results get worse + + return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image, + 'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask, + 'original_size': torch.Tensor(original_size), + 'image_size': torch.Tensor(image_size), + 'case': dataset} # case to be compatible with polyp video dataset + + def get_dataset_name_from_path(self, path): + for dataset in self.datasets: + if dataset in path: + return dataset + return "" + + def __getitem__(self, index): + image = self.cv2_loader(self.images[index], is_mask=False) + gt = self.cv2_loader(self.gts[index], is_mask=True) + dataset = self.get_dataset_name_from_path(self.images[index]) + return self.process_image_gt(image, gt, dataset) + + def filter_files_and_get_ds_mean_and_std(self): + assert len(self.images) == len(self.gts) + images = [] + gts = [] + ds_mean = 0 + ds_std = 0 + for img_path, gt_path in zip(self.images, self.gts): + if any([ex_ds in img_path for ex_ds in EXCLUDE_DS]): + continue + img = Image.open(img_path) + gt = Image.open(gt_path) + if img.size == gt.size: + images.append(img_path) + gts.append(gt_path) + ds_mean += np.array(img).mean() + ds_std += np.array(img).std() + self.images = images + self.gts = gts + self.mean = ds_mean / len(self.images) + self.std = ds_std / len(self.images) + + def rgb_loader(self, path): + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + def binary_loader(self, path): + # with open(path, 'rb') as f: + # img = Image.open(f) + # return img.convert('1') + img = cv2.imread(path, 0) + return img + + def cv2_loader(self, path, is_mask): + if is_mask: + img = cv2.imread(path, 0) + img[img > 0] = 1 + else: + img = cv2.cvtColor(cv2.imread( + path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + return img + + def resize(self, img, gt): + assert img.size == gt.size + w, h = img.size + if h < self.trainsize or w < self.trainsize: + h = max(h, self.trainsize) + w = max(w, self.trainsize) + return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) + else: + return img, gt + + def __len__(self): + # return 32 + return self.size + + +class SuperpixPolypDataset(PolypDataset): + def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None): + self.trainsize = trainsize + self.augmentations = augmentations + self.datasets = datasets + self.image_size = image_size + # print(self.augmentations) + if image_root is not None and gt_root is not None: + self.images = [ + os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] + self.gts = [ + os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png') and 'superpix' in f] + # also look in subdirectories + for subdir in os.listdir(image_root): + # if not dir, continue + if not os.path.isdir(os.path.join(image_root, subdir)): + continue + subdir_image_root = os.path.join(image_root, subdir) + subdir_gt_root = os.path.join(gt_root, subdir) + self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir( + subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')]) + self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir( + subdir_gt_root) if f.endswith('.png')]) + + else: + self.images, self.gts = self.get_image_gt_pairs( + root, split="train" if train else "test", datasets=self.datasets) + self.images = sorted(self.images) + self.gts = sorted(self.gts) + if not 'VPS' in root: + self.filter_files_and_get_ds_mean_and_std() + if ds_mean is not None and ds_std is not None: + self.mean, self.std = ds_mean, ds_std + self.size = len(self.images) + self.train = train + self.sam_trans = sam_trans + if self.sam_trans is not None: + # sam trans takes care of norm + self.mean, self.std = 0 , 1 + + + def __getitem__(self, index): + image = self.cv2_loader(self.images[index], is_mask=False) + gt = self.cv2_loader(self.gts[index], is_mask=False) + gt = gt[:, :, 0] + fgpath = os.path.basename(self.gts[index]).split('.png')[0].split('superpix-MIDDLE_') + fgpath = os.path.join(os.path.dirname(self.gts[index]), 'fgmask_' + fgpath[1] + '.png') + fg = self.cv2_loader(fgpath, is_mask=True) + dataset = self.get_dataset_name_from_path(self.images[index]) + + # randomly choose a superpixels from the gt + gt[1-fg] = 0 + sp_id = random.choice(np.unique(gt)[1:]) + sp = (gt == sp_id).astype(np.uint8) + + + out = self.process_image_gt(image, gt, dataset) + support_image, support_sp, dataset = out["image"], out["label"], out["case"] + + out = self.process_image_gt(image, sp, dataset) + query_image, query_sp, dataset = out["image"], out["label"], out["case"] + + # TODO tile the masks to have 3 channels? + + support_bg_mask = 1 - support_sp + support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask} + + batch = {"support_images" : [[support_image]], + "support_mask" : [[support_masks]], + "query_images" : [query_image], + "query_labels" : [query_sp], + "scan_id" : [dataset] + } + + return batch + + +def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None): + transform_train, transform_test = get_polyp_transform() + image_root = './data/PolypDataset/TrainDataset/images/' + gt_root = './data/PolypDataset/TrainDataset/superpixels/' + ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, + augmentations=transform_train, + sam_trans=sam_trans, + image_size=image_size) + + return ds_train + +def get_polyp_dataset(image_size, sam_trans=None): + transform_train, transform_test = get_polyp_transform() + image_root = './data/PolypDataset/TrainDataset/images/' + gt_root = './data/PolypDataset/TrainDataset/masks/' + ds_train = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, + augmentations=transform_test, sam_trans=sam_trans, train=True, image_size=image_size) + image_root = './data/PolypDataset/TestDataset/test/images/' + gt_root = './data/PolypDataset/TestDataset/test/masks/' + ds_test = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, train=False, + augmentations=transform_test, sam_trans=sam_trans, image_size=image_size) + return ds_train, ds_test + + +def get_tests_polyp_dataset(sam_trans): + transform_train, transform_test = get_polyp_transform() + + image_root = './data/polyp/TestDataset/Kvasir/images/' + gt_root = './data/polyp/TestDataset/Kvasir/masks/' + ds_Kvasir = PolypDataset( + image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans) + + image_root = './data/polyp/TestDataset/CVC-ClinicDB/images/' + gt_root = './data/polyp/TestDataset/CVC-ClinicDB/masks/' + ds_ClinicDB = PolypDataset( + image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans) + + image_root = './data/polyp/TestDataset/CVC-ColonDB/images/' + gt_root = './data/polyp/TestDataset/CVC-ColonDB/masks/' + ds_ColonDB = PolypDataset( + image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans) + + image_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/images/' + gt_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/masks/' + ds_ETIS = PolypDataset( + image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans) + + return ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS + + +if __name__ == '__main__': + # create_train_val_test_split_for_polyps() + create_suppport_set_for_polyps() diff --git a/dataloaders/PolypTransforms.py b/dataloaders/PolypTransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3981d098f0f296c3dc1d004aadcb017c3464db --- /dev/null +++ b/dataloaders/PolypTransforms.py @@ -0,0 +1,626 @@ +from __future__ import division +import torch +import math +import sys +import random +from PIL import Image + +try: + import accimage +except ImportError: + accimage = None +import numpy as np +import numbers +import types +import collections +import warnings + +from torchvision.transforms import functional as F + +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + +__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad", + "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", + "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop", + "ColorJitter", "RandomRotation", "RandomAffine", + "RandomPerspective"] + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, mask): + for t in self.transforms: + img, mask = t(img, mask) + return img, mask + + +class ToTensor(object): + def __call__(self, img, mask): + # return F.to_tensor(img), F.to_tensor(mask) + img = np.array(img) + img = torch.from_numpy(img).permute(2, 0, 1).float() # TODO add division by 255 to match torch.ToTensor()? + mask = torch.from_numpy(np.array(mask)).float() + return img, mask + + +class ToPILImage(object): + def __init__(self, mode=None): + self.mode = mode + + def __call__(self, img, mask): + return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode) + + +class Normalize(object): + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, img, mask): + return F.normalize(img, self.mean, self.std, self.inplace), mask + + +class Resize(object): + def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True): + assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + self.do_mask = do_mask + + def __call__(self, img, mask): + if self.do_mask: + return F.resize(img, self.size, Image.BICUBIC), F.resize(mask, self.size, Image.BICUBIC) + else: + return F.resize(img, self.size, Image.BICUBIC), mask + + +class CenterCrop(object): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img, mask): + return F.center_crop(img, self.size), F.center_crop(mask, self.size) + + +class Pad(object): + def __init__(self, padding, fill=0, padding_mode='constant'): + assert isinstance(padding, (numbers.Number, tuple)) + assert isinstance(fill, (numbers.Number, str, tuple)) + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def __call__(self, img, mask): + return F.pad(img, self.padding, self.fill, self.padding_mode), \ + F.pad(mask, self.padding, self.fill, self.padding_mode) + + +class Lambda(object): + def __init__(self, lambd): + assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" + self.lambd = lambd + + def __call__(self, img, mask): + return self.lambd(img), self.lambd(mask) + + +class Lambda_image(object): + def __init__(self, lambd): + assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" + self.lambd = lambd + + def __call__(self, img, mask): + return self.lambd(img), mask + + +class RandomTransforms(object): + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + +class RandomApply(RandomTransforms): + def __init__(self, transforms, p=0.5): + super(RandomApply, self).__init__(transforms) + self.p = p + + def __call__(self, img, mask): + if self.p < random.random(): + return img, mask + for t in self.transforms: + img, mask = t(img, mask) + return img, mask + + +class RandomOrder(RandomTransforms): + def __call__(self, img, mask): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img, mask = self.transforms[i](img, mask) + return img, mask + + +class RandomChoice(RandomTransforms): + def __call__(self, img, mask): + t = random.choice(self.transforms) + return t(img, mask) + + +class RandomCrop(object): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, output_size): + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img, mask): + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.size[0] < self.size[1]: + img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.size[1] < self.size[0]: + img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + + return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, mask): + if random.random() < self.p: + return F.hflip(img), F.hflip(mask) + return img, mask + + +class RandomVerticalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, mask): + if random.random() < self.p: + return F.vflip(img), F.vflip(mask) + return img, mask + + +class RandomPerspective(object): + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): + self.p = p + self.interpolation = interpolation + self.distortion_scale = distortion_scale + + def __call__(self, img, mask): + if not F._is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if random.random() < self.p: + width, height = img.size + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + return F.perspective(img, startpoints, endpoints, self.interpolation), \ + F.perspective(mask, startpoints, endpoints, Image.NEAREST) + return img, mask + + @staticmethod + def get_params(width, height, distortion_scale): + half_height = int(height / 2) + half_width = int(width / 2) + topleft = (random.randint(0, int(distortion_scale * half_width)), + random.randint(0, int(distortion_scale * half_height))) + topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), + random.randint(0, int(distortion_scale * half_height))) + botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), + random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) + botleft = (random.randint(0, int(distortion_scale * half_width)), + random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) + startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] + endpoints = [topleft, topright, botright, botleft] + return startpoints, endpoints + + +class RandomResizedCrop(object): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if (in_ratio < min(ratio)): + w = img.size[0] + h = w / min(ratio) + elif (in_ratio > max(ratio)): + h = img.size[1] + w = h * max(ratio) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img, mask): + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \ + F.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST) + + +class FiveCrop(object): + def __init__(self, size): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + + def __call__(self, img, mask): + return F.five_crop(img, self.size), F.five_crop(mask, self.size) + + +class TenCrop(object): + def __init__(self, size, vertical_flip=False): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + self.vertical_flip = vertical_flip + + def __call__(self, img, mask): + return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip) + + +class ColorJitter(object): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img, mask): + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return transform(img, mask) + + +class RandomRotation(object): + def __init__(self, degrees, resample=False, expand=False, center=None): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError("If degrees is a sequence, it must be of len 2.") + self.degrees = degrees + + self.resample = resample + self.expand = expand + self.center = center + + @staticmethod + def get_params(degrees): + angle = random.uniform(degrees[0], degrees[1]) + + return angle + + def __call__(self, img, mask): + angle = self.get_params(self.degrees) + + return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \ + F.rotate(mask, angle, Image.NEAREST, self.expand, self.center) + + +class RandomAffine(object): + def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ + "shear should be a list or tuple and it must be of length 2." + self.shear = shear + else: + self.shear = shear + + self.resample = resample + self.fillcolor = fillcolor + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, img_size): + angle = random.uniform(degrees[0], degrees[1]) + if translate is not None: + max_dx = translate[0] * img_size[0] + max_dy = translate[1] * img_size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = random.uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + shear = random.uniform(shears[0], shears[1]) + else: + shear = 0.0 + + return angle, translations, scale, shear + + def __call__(self, img, mask): + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) + return F.affine(img, *ret, interpolation=Image.BILINEAR, fill=self.fillcolor), \ + F.affine(mask, *ret, interpolation=Image.NEAREST, fill=self.fillcolor) + + + +def get_cub_transform(): + transform_train = Compose([ + ToPILImage(), + Resize((256, 256)), + RandomHorizontalFlip(), + RandomAffine(22, scale=(0.75, 1.25)), + ToTensor(), + Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225]) + ]) + transform_test = Compose([ + ToPILImage(), + Resize((256, 256)), + ToTensor(), + Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225]) + ]) + return transform_train, transform_test + + +def get_glas_transform(): + transform_train = Compose([ + ToPILImage(), + # Resize((256, 256)), + ColorJitter(brightness=0.2, + contrast=0.2, + saturation=0.2, + hue=0.1), + RandomHorizontalFlip(), + RandomAffine(5, scale=(0.75, 1.25)), + ToTensor(), + # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225]) + ]) + transform_test = Compose([ + ToPILImage(), + # Resize((256, 256)), + ToTensor(), + # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225]) + ]) + return transform_train, transform_test + +# def get_glas_transform(): +# transform_train = Compose([ +# ToPILImage(), +# Resize((256, 256)), +# ColorJitter(brightness=0.2, +# contrast=0.2, +# saturation=0.2, +# hue=0.1), +# RandomHorizontalFlip(), +# RandomAffine(5, scale=(0.75, 1.25)), +# ToTensor(), +# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225]) +# ]) +# transform_test = Compose([ +# ToPILImage(), +# Resize((256, 256)), +# ToTensor(), +# Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) +# ]) +# return transform_train, transform_test + + +def get_monu_transform(args): + Idim = int(args['Idim']) + transform_train = Compose([ + ToPILImage(), + # Resize((Idim, Idim)), + ColorJitter(brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.1), + RandomHorizontalFlip(), + RandomAffine(int(args['rotate']), scale=(float(args['scale1']), float(args['scale2']))), + ToTensor(), + # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78]) + ]) + transform_test = Compose([ + ToPILImage(), + # Resize((Idim, Idim)), + ToTensor(), + # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78]) + ]) + return transform_train, transform_test + + +def get_polyp_transform(): + transform_train = Compose([ + # Resize((352, 352)), + ToPILImage(), + ColorJitter(brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.1), + RandomVerticalFlip(), + RandomHorizontalFlip(), + RandomAffine(90, scale=(0.75, 1.25)), + ToTensor(), + # Normalize([105.61, 63.69, 45.67], + # [83.08, 55.86, 42.59]) + ]) + transform_test = Compose([ + # Resize((352, 352)), + ToPILImage(), + ToTensor(), + # Normalize([105.61, 63.69, 45.67], + # [83.08, 55.86, 42.59]) + ]) + return transform_train, transform_test + + +def get_polyp_support_train_transform(): + transform_train = Compose([ + ColorJitter(brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.1), + RandomVerticalFlip(), + RandomHorizontalFlip(), + RandomAffine(90, scale=(0.75, 1.25)), + ]) + + return transform_train \ No newline at end of file diff --git a/dataloaders/SimpleDataset.py b/dataloaders/SimpleDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e12092b638bec466fbf40770b3ca46d8707fd54f --- /dev/null +++ b/dataloaders/SimpleDataset.py @@ -0,0 +1,61 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +""" +simple dataset, gets the images and masks as list together with a transform function that +shoudl receive both the image and the mask. +loop means how many times to loop the dataset per epoch +""" + +class SimpleDataset(torch.utils.data.Dataset): + def __init__(self, image_list, mask_list, transform=None, norm_func=None, loops=10, modality="", debug=False, image_size=None): + self.image_list = image_list + if image_size is not None: + if len(image_size) == 1: + image_size = (image_size, image_size) + self.image_size = image_size + else: + self.image_size = image_list[0].shape[-2:] + self.mask_list = mask_list + self.transform = transform + self.norm_func = norm_func + self.loops = loops + self.modality = modality + self.debug = debug + + def __len__(self): + return len(self.image_list) * self.loops + + def __getitem__(self, idx): + idx = idx % (len(self.image_list)) + image = self.image_list[idx].numpy() + mask = self.mask_list[idx].to(dtype=torch.uint8).numpy() + if self.modality == "CT": + image = image.astype(np.uint8) + if self.transform: + image, mask = self.transform(image, mask) + else: + # mask = np.repeat(mask[..., np.newaxis], 3, axis=-1) + if self.transform: + image, mask = self.transform(image, mask) + + if self.norm_func: + image = self.norm_func(image) + + mask[mask != 0] = 1 + + if self.image_size != image.shape[-2:]: + image = torch.nn.functional.interpolate(torch.tensor(image).unsqueeze(0), self.image_size, mode='bilinear').squeeze(0) + mask = torch.nn.functional.interpolate(torch.tensor(mask).unsqueeze(0).unsqueeze(0), self.image_size, mode='nearest').squeeze(0).squeeze(0) + + # plot image and mask + if self.debug: + fig = plt.figure() + plt.imshow((image[0]- image.min()) / (image.max() - image.min())) + plt.imshow(mask, alpha=0.5) + plt.savefig("debug/support_image_mask.png") + plt.close(fig) + + image_size = torch.tensor(tuple(image.shape[-2:])) + return image, mask \ No newline at end of file diff --git a/dataloaders/__init__.py b/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataloaders/augutils.py b/dataloaders/augutils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7120afc3fbc23b2a1427b6f71e59be2e7f4cad4 --- /dev/null +++ b/dataloaders/augutils.py @@ -0,0 +1,224 @@ +''' +Utilities for augmentation. Partly credit to Dr. Jo Schlemper +''' +from os.path import join + +import torch +import numpy as np +import torchvision.transforms as deftfx +import dataloaders.image_transforms as myit +import copy +from util.consts import IMG_SIZE +import time +import functools + + +def get_sabs_aug(input_size, use_3d=False): + sabs_aug = { + # turn flipping off as medical data has fixed orientations + 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25}, + 'affine': { + 'rotate': 5, + 'shift': (5, 5), + 'shear': 5, + 'scale': (0.9, 1.2), + }, + 'elastic': {'alpha': 10, 'sigma': 5}, + 'patch': input_size, + 'reduce_2d': True, + '3d': use_3d, + 'gamma_range': (0.5, 1.5) + } + return sabs_aug + + +def get_sabs_augv3(input_size): + sabs_augv3 = { + 'flip': {'v': False, 'h': False, 't': False, 'p': 0.25}, + 'affine': { + 'rotate': 30, + 'shift': (30, 30), + 'shear': 30, + 'scale': (0.8, 1.3), + }, + 'elastic': {'alpha': 20, 'sigma': 5}, + 'patch': input_size, + 'reduce_2d': True, + 'gamma_range': (0.2, 1.8) + } + return sabs_augv3 + + +def get_aug(which_aug, input_size): + if which_aug == 'sabs_aug': + return get_sabs_aug(input_size) + elif which_aug == 'aug_v3': + return get_sabs_augv3(input_size) + else: + raise NotImplementedError + +# augs = { +# 'sabs_aug': get_sabs_aug, +# 'aug_v3': get_sabs_augv3, # more aggresive +# } + + +def get_geometric_transformer(aug, order=3): + """order: interpolation degree. Select order=0 for augmenting segmentation """ + affine = aug['aug'].get('affine', 0) + alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha'] + sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma'] + flip = aug['aug'].get( + 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125}) + + tfx = [] + if 'flip' in aug['aug']: + tfx.append(myit.RandomFlip3D(**flip)) + + if 'affine' in aug['aug']: + tfx.append(myit.RandomAffine(affine.get('rotate'), + affine.get('shift'), + affine.get('shear'), + affine.get('scale'), + affine.get('scale_iso', True), + order=order)) + + if 'elastic' in aug['aug']: + tfx.append(myit.ElasticTransform(alpha, sigma)) + input_transform = deftfx.Compose(tfx) + return input_transform + + +def get_geometric_transformer_3d(aug, order=3): + """order: interpolation degree. Select order=0 for augmenting segmentation """ + affine = aug['aug'].get('affine', 0) + alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha'] + sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma'] + flip = aug['aug'].get( + 'flip', {'v': True, 'h': True, 't': True, 'p': 0.125}) + + tfx = [] + if 'flip' in aug['aug']: + tfx.append(myit.RandomFlip3D(**flip)) + + if 'affine' in aug['aug']: + tfx.append(myit.RandomAffine(affine.get('rotate'), + affine.get('shift'), + affine.get('shear'), + affine.get('scale'), + affine.get('scale_iso', True), + order=order, + use_3d=True)) + + if 'elastic' in aug['aug']: + tfx.append(myit.ElasticTransform(alpha, sigma)) + input_transform = deftfx.Compose(tfx) + return input_transform + + +def gamma_transform(img, aug): + gamma_range = aug['aug']['gamma_range'] + if isinstance(gamma_range, tuple): + gamma = np.random.rand() * \ + (gamma_range[1] - gamma_range[0]) + gamma_range[0] + cmin = img.min() + irange = (img.max() - cmin + 1e-5) + + img = img - cmin + 1e-5 + img = irange * np.power(img * 1.0 / irange, gamma) + img = img + cmin + + elif gamma_range == False: + pass + else: + raise ValueError( + "Cannot identify gamma transform range {}".format(gamma_range)) + return img + + +def get_intensity_transformer(aug): + """some basic intensity transforms""" + return functools.partial(gamma_transform, aug=aug) + + +def transform_with_label(aug): + """ + Doing image geometric transform + Proposed image to have the following configurations + [H x W x C + CL] + Where CL is the number of channels for the label. It is NOT in one-hot form + """ + + geometric_tfx = get_geometric_transformer(aug) + intensity_tfx = get_intensity_transformer(aug) + + def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs): + """ + Args + comp: a numpy array with shape [H x W x C + c_label] + c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1) + nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label + + """ + comp = copy.deepcopy(comp) + if (use_onehot is True) and (c_label != 1): + raise NotImplementedError( + "Only allow compact label, also the label can only be 2d") + assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label" + + # geometric transform + _label = comp[..., c_img] + _h_label = np.float32(np.arange(nclass) == (_label[..., None])) + # _h_label = np.float32(_label[..., None]) + comp = np.concatenate([comp[..., :c_img], _h_label], -1) + comp = geometric_tfx(comp) + # round one_hot labels to 0 or 1 + t_label_h = comp[..., c_img:] + t_label_h = np.rint(t_label_h) + assert t_label_h.max() <= 1 + t_img = comp[..., 0: c_img] + + # intensity transform + t_img = intensity_tfx(t_img) + + if use_onehot is True: + t_label = t_label_h + else: + t_label = np.expand_dims(np.argmax(t_label_h, axis=-1), -1) + return t_img, t_label + + return transform + + +def transform(scan, label, nclass, geometric_tfx, intensity_tfx): + """ + Args + scan: a numpy array with shape [D x H x W x C] + label: a numpy array with shape [D x H x W x 1] + """ + assert len(scan.shape) == 4, "Input scan must be 4D" + if len(label.shape) == 3: + label = np.expand_dims(label, -1) + + # geometric transform + comp = copy.deepcopy(np.concatenate( + [scan, label], -1)) # [D x H x W x C + 1] + _label = comp[..., -1] + _h_label = np.float32(np.arange(nclass) == (_label[..., None])) + comp = np.concatenate([comp[..., :-1], _h_label], -1) + # change comp to be H x W x D x C + 1 + comp = np.transpose(comp, (1, 2, 0, 3)) + comp = geometric_tfx(comp) + t_label_h = comp[..., 1:] + t_label_h = np.rint(t_label_h) + assert t_label_h.max() <= 1 + t_img = comp[..., 0:1] + + # intensity transform + t_img = intensity_tfx(t_img) + return t_img, t_label_h + + +def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx): + return transform(scan, label, nclass, geometric_tfx, intensity_tfx) + diff --git a/dataloaders/common.py b/dataloaders/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d65ed55f33aff2684f01cd42208287677d2d0f6a --- /dev/null +++ b/dataloaders/common.py @@ -0,0 +1,263 @@ +""" +Dataset classes for common uses +Extended from vanilla PANet code by Wang et al. +""" +import random +import torch + +from torch.utils.data import Dataset + +class BaseDataset(Dataset): + """ + Base Dataset + Args: + base_dir: + dataset directory + """ + def __init__(self, base_dir): + self._base_dir = base_dir + self.aux_attrib = {} + self.aux_attrib_args = {} + self.ids = [] # must be overloaded in subclass + + def add_attrib(self, key, func, func_args): + """ + Add attribute to the data sample dict + + Args: + key: + key in the data sample dict for the new attribute + e.g. sample['click_map'], sample['depth_map'] + func: + function to process a data sample and create an attribute (e.g. user clicks) + func_args: + extra arguments to pass, expected a dict + """ + if key in self.aux_attrib: + raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key)) + else: + self.set_attrib(key, func, func_args) + + def set_attrib(self, key, func, func_args): + """ + Set attribute in the data sample dict + + Args: + key: + key in the data sample dict for the new attribute + e.g. sample['click_map'], sample['depth_map'] + func: + function to process a data sample and create an attribute (e.g. user clicks) + func_args: + extra arguments to pass, expected a dict + """ + self.aux_attrib[key] = func + self.aux_attrib_args[key] = func_args + + def del_attrib(self, key): + """ + Remove attribute in the data sample dict + + Args: + key: + key in the data sample dict + """ + self.aux_attrib.pop(key) + self.aux_attrib_args.pop(key) + + def subsets(self, sub_ids, sub_args_lst=None): + """ + Create subsets by ids + + Args: + sub_ids: + a sequence of sequences, each sequence contains data ids for one subset + sub_args_lst: + a list of args for some subset-specific auxiliary attribute function + """ + + indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids] + if sub_args_lst is not None: + subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args) + for index, args in zip(indices, sub_args_lst)] + else: + subsets = [Subset(dataset=self, indices=index) for index in indices] + return subsets + + def __len__(self): + pass + + def __getitem__(self, idx): + pass + + +class ReloadPairedDataset(Dataset): + """ + Make pairs of data from dataset + Eable only loading part of the entire data in each epoach and then reload to the next part + Args: + datasets: + source datasets, expect a list of Dataset. + Each dataset indices a certain class. It contains a list of all z-indices of this class for each scan + n_elements: + number of elements in a pair + curr_max_iters: + number of pairs in an epoch + pair_based_transforms: + some transformation performed on a pair basis, expect a list of functions, + each function takes a pair sample and return a transformed one. + """ + def __init__(self, datasets, n_elements, curr_max_iters, + pair_based_transforms=None): + super().__init__() + self.datasets = datasets + self.n_datasets = len(self.datasets) + self.n_data = [len(dataset) for dataset in self.datasets] + self.n_elements = n_elements + self.curr_max_iters = curr_max_iters + self.pair_based_transforms = pair_based_transforms + self.update_index() + + def update_index(self): + """ + update the order of batches for the next episode + """ + + # update number of elements for each subset + if hasattr(self, 'indices'): + n_data_old = self.n_data # DEBUG + self.n_data = [len(dataset) for dataset in self.datasets] + + if isinstance(self.n_elements, list): + self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use + for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use + for i_iter in range(self.curr_max_iters)] # sample iterations + + elif self.n_elements > self.n_datasets: + raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets") + else: + self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx])) + for dataset_idx in random.sample(range(self.n_datasets), + k=n_elements)] + for i in range(curr_max_iters)] + + def __len__(self): + return self.curr_max_iters + + def __getitem__(self, idx): + sample = [self.datasets[dataset_idx][data_idx] + for dataset_idx, data_idx in self.indices[idx]] + if self.pair_based_transforms is not None: + for transform, args in self.pair_based_transforms: + sample = transform(sample, **args) + return sample + +class Subset(Dataset): + """ + Subset of a dataset at specified indices. Used for seperating a dataset by class in our context + + Args: + dataset: + The whole Dataset + indices: + Indices of samples of the current class in the entire dataset + sub_attrib_args: + Subset-specific arguments for attribute functions, expected a dict + """ + def __init__(self, dataset, indices, sub_attrib_args=None): + self.dataset = dataset + self.indices = indices + self.sub_attrib_args = sub_attrib_args + + def __getitem__(self, idx): + if self.sub_attrib_args is not None: + for key in self.sub_attrib_args: + # Make sure the dataset already has the corresponding attributes + # Here we only make the arguments subset dependent + # (i.e. pass different arguments for each subset) + self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key]) + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + +class ValidationDataset(Dataset): + """ + Dataset for validation + + Args: + dataset: + source dataset with a __getitem__ method + test_classes: + test classes + npart: int. number of parts, used for evaluation when assigning support images + + """ + def __init__(self, dataset, test_classes: list, npart: int): + super().__init__() + self.dataset = dataset + self.__curr_cls = None + self.test_classes = test_classes + self.dataset.aux_attrib = None + self.npart = npart + + def set_curr_cls(self, curr_cls): + assert curr_cls in self.test_classes + self.__curr_cls = curr_cls + + def get_curr_cls(self): + return self.__curr_cls + + def read_dataset(self): + """ + override original read_dataset to allow reading with z_margin + """ + raise NotImplementedError + + def __len__(self): + return len(self.dataset) + + def label_strip(self, label): + """ + mask unrelated labels out + """ + out = torch.where(label == self.__curr_cls, + torch.ones_like(label), torch.zeros_like(label)) + return out + + def __getitem__(self, idx): + if self.__curr_cls is None: + raise Exception("Please initialize current class first") + + sample = self.dataset[idx] + sample["label"] = self.label_strip( sample["label"] ) + sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy() + + labelname = self.dataset.all_label_names[self.__curr_cls] + z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) + z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) + sample["z_min"], sample["z_max"] = z_min, z_max + try: + part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart)) + except: + part_assign = 0 + # print("###### DATASET: support only has one valid slice ######") + if part_assign < 0: + part_assign = 0 + elif part_assign >= self.npart: + part_assign = self.npart - 1 + sample["part_assign"] = part_assign + sample["case"] = sample["scan_id"] + + return sample + + def get_support_set(self, config, n_support=3): + support_batched = self.dataset.get_support(curr_class=self.__curr_cls, class_idx= [self.__curr_cls], scan_idx=config["support_idx"], npart=config["task"]["npart"]) + + support_images = [img for way in support_batched["support_images"] for img in way] + support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"] for fgmask in way] + support_scan_id = self.dataset.potential_support_sid + return {"support_images": support_images, "support_labels": support_labels, "support_scan_id": support_scan_id} + + + diff --git a/dataloaders/dataset_utils.py b/dataloaders/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d6f5752f144df91fbed9336c9693b1810e19871 --- /dev/null +++ b/dataloaders/dataset_utils.py @@ -0,0 +1,128 @@ +""" +Utils for datasets +""" +import functools +import numpy as np + +import os +import sys +import nibabel as nib +import numpy as np +import pdb +import SimpleITK as sitk + +DATASET_INFO = { + "CHAOST2": { + 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], + 'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"], + '_SEP': [0, 4, 8, 12, 16, 20], + 'MODALITY': 'MR', + 'LABEL_GROUP': { + 'pa_all': set(range(1, 5)), + 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes + 1: set([2, 3]), # lower_abdomen + }, + }, + + "SABS": { + 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], + + 'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\ + "PS_VEIN", "PANCREAS", "AG_R", "AG_L"], + '_SEP': [0, 6, 12, 18, 24, 30], + 'MODALITY': 'CT', + 'LABEL_GROUP':{ + 'pa_all': set( [1,2,3,6] ), + 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneis are testing + 1: set( [2,3] ), # lower_abdomen + } + }, + "LITS17": { + 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], + + 'REAL_LABEL_NAME': ["BGD", "LIVER", "TUMOR"], + '_SEP': [0, 26, 52, 78, 104], + 'MODALITY': 'CT', + 'LABEL_GROUP':{ + 'pa_all': set( [1 , 2] ), + 0: set([1 ] ), # liver + 1: set( [ 2] ), # tumor + 2: set([1,2]) # liver + tumor + } + + } + +} + +def read_nii_bysitk(input_fid, peel_info = False): + """ read nii to numpy through simpleitk + + peelinfo: taking direction, origin, spacing and metadata out + """ + img_obj = sitk.ReadImage(input_fid) + img_np = sitk.GetArrayFromImage(img_obj) + if peel_info: + info_obj = { + "spacing": img_obj.GetSpacing(), + "origin": img_obj.GetOrigin(), + "direction": img_obj.GetDirection(), + "array_size": img_np.shape + } + return img_np, info_obj + else: + return img_np + + +def get_CT_statistics(scan_fids): + """ + As CT are quantitative, get mean and std for CT images for image normalizing + As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading + """ + total_val = 0 + n_pix = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_val += in_img.sum() + n_pix += np.prod(in_img.shape) + del in_img + meanval = total_val / n_pix + + total_var = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_var += np.sum((in_img - meanval) ** 2 ) + del in_img + var_all = total_var / n_pix + + global_std = var_all ** 0.5 + + return meanval, global_std + +def MR_normalize(x_in): + return (x_in - x_in.mean()) / x_in.std() + +def CT_normalize(x_in, ct_mean, ct_std): + """ + Normalizing CT images, based on global statistics + """ + return (x_in - ct_mean) / ct_std + +def get_normalize_op(modality, fids, ct_mean=None, ct_std=None): + """ + As title + Args: + modality: CT or MR + fids: fids for the fold + """ + if modality == 'MR': + return MR_normalize + + elif modality == 'CT': + if ct_mean is None or ct_std is None: + ct_mean, ct_std = get_CT_statistics(fids) + # debug + print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean} STD {ct_std} ######') + + return functools.partial(CT_normalize, ct_mean=ct_mean, ct_std=ct_std) + + diff --git a/dataloaders/dev_customized_med.py b/dataloaders/dev_customized_med.py new file mode 100644 index 0000000000000000000000000000000000000000..406cc63e017df7d00a93fccaf54293ada1374d91 --- /dev/null +++ b/dataloaders/dev_customized_med.py @@ -0,0 +1,250 @@ +""" +Customized dataset. Extended from vanilla PANet script by Wang et al. +""" + +import os +import random +import torch +import numpy as np + +from dataloaders.common import ReloadPairedDataset, ValidationDataset +from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset + +def attrib_basic(_sample, class_id): + """ + Add basic attribute + Args: + _sample: data sample + class_id: class label asscociated with the data + (sometimes indicting from which subset the data are drawn) + """ + return {'class_id': class_id} + +def getMaskOnly(label, class_id, class_ids): + """ + Generate FG/BG mask from the segmentation mask + + Args: + label: + semantic mask + scribble: + scribble mask + class_id: + semantic class of interest + class_ids: + all class id in this episode + """ + # Dense Mask + fg_mask = torch.where(label == class_id, + torch.ones_like(label), torch.zeros_like(label)) + bg_mask = torch.where(label != class_id, + torch.ones_like(label), torch.zeros_like(label)) + for class_id in class_ids: + bg_mask[label == class_id] = 0 + + return {'fg_mask': fg_mask, + 'bg_mask': bg_mask} + +def getMasks(*args, **kwargs): + raise NotImplementedError + +def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True): + """ + Postprocess paired sample for fewshot settings + For now only 1-way is tested but we leave multi-way possible (inherited from original PANet) + + Args: + paired_sample: + data sample from a PairedDataset + n_ways: + n-way few-shot learning + n_shots: + n-shot few-shot learning + cnt_query: + number of query images for each class in the support set + coco: + MS COCO dataset. This is from the original PANet dataset but lets keep it for further extension + mask_only: + only give masks and no scribbles/ instances. Suitable for medical images (for now) + """ + if not mask_only: + raise NotImplementedError + ###### Compose the support and query image list ###### + cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # seperation for supports and queries + + # support class ids + class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query) + + # support images + support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)] + for i in range(n_ways)] # fetch support images for each class + + # support image labels + if coco: + support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]] + for j in range(n_shots)] for i in range(n_ways)] + else: + support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)] + for i in range(n_ways)] + + if not mask_only: + support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)] + for i in range(n_ways)] + support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)] + for i in range(n_ways)] + else: + support_insts = [] + + # query images, masks and class indices + query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways) + for j in range(cnt_query[i])] + if coco: + query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]] + for i in range(n_ways) for j in range(cnt_query[i])] + else: + query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways) + for j in range(cnt_query[i])] + query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1 + for x in set(np.unique(query_label)) & set(class_ids)]) + for query_label in query_labels] + + ###### Generate support image masks ###### + if not mask_only: + support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot], + class_ids[way], class_ids) + for shot in range(n_shots)] for way in range(n_ways)] + else: + support_mask = [[getMaskOnly(support_labels[way][shot], + class_ids[way], class_ids) + for shot in range(n_shots)] for way in range(n_ways)] + + ###### Generate query label (class indices in one episode, i.e. the ground truth)###### + query_labels_tmp = [torch.zeros_like(x) for x in query_labels] + for i, query_label_tmp in enumerate(query_labels_tmp): + query_label_tmp[query_labels[i] == 255] = 255 + for j in range(n_ways): + query_label_tmp[query_labels[i] == class_ids[j]] = j + 1 + + ###### Generate query mask for each semantic class (including BG) ###### + # BG class + query_masks = [[torch.where(query_label == 0, + torch.ones_like(query_label), + torch.zeros_like(query_label))[None, ...],] + for query_label in query_labels] + # Other classes in query image + for i, query_label in enumerate(query_labels): + for idx in query_cls_idx[i][1:]: + mask = torch.where(query_label == class_ids[idx - 1], + torch.ones_like(query_label), + torch.zeros_like(query_label))[None, ...] + query_masks[i].append(mask) + + + return {'class_ids': class_ids, + 'support_images': support_images, + 'support_mask': support_mask, + 'support_inst': support_insts, # leave these interfaces + 'support_scribbles': support_scribbles, + + 'query_images': query_images, + 'query_labels': query_labels_tmp, + 'query_masks': query_masks, + 'query_cls_idx': query_cls_idx, + } + + +def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load, + transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs): + """ + Dataset wrapper + Args: + dataset_name: + indicates what dataset to use + base_dir: + dataset directory + mode: + which mode to use + choose from ('train', 'val', 'trainval', 'trainaug') + idx_split: + index of split + scan_per_load: + number of scans to load into memory as the dataset is large + use that together with reload_buffer + transforms: + transformations to be performed on images/masks + act_labels: + active labels involved in training process. Should be a subset of all labels + n_ways: + n-way few-shot learning, should be no more than # of object class labels + n_shots: + n-shot few-shot learning + max_iters_per_load: + number of pairs per load (epoch size) + n_queries: + number of query images + fix_parent_len: + fixed length of the parent dataset + """ + med_set = ManualAnnoDataset + + + mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\ + scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\ + exclude_list = exclude_list, **kwargs) + + mydataset.add_attrib('basic', attrib_basic, {}) + + # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside + subsets = mydataset.subsets([{'basic': {'class_id': ii}} + for ii, _ in enumerate(mydataset.label_name)]) + + # Choose the classes of queries + cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways) + # Number of queries for each way + # Set the number of images for each class + n_elements = [n_shots + x for x in cnt_query] # supports + [i] queries + # Create paired dataset. We do not include background. + paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load, + pair_based_transforms=[ + (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots, + 'cnt_query': cnt_query, 'mask_only': True})]) + return paired_data, mydataset + +def update_loader_dset(loader, parent_set): + """ + Update data loader and the parent dataset behind + Args: + loader: actual dataloader + parent_set: parent dataset which actually stores the data + """ + parent_set.reload_buffer() + loader.dataset.update_index() + print(f'###### Loader and dataset have been updated ######' ) + +def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs): + """ + validation set for med images + Args: + dataset_name: + indicates what dataset to use + base_dir: + SABS dataset directory + mode: (original split) + which split to use + choose from ('train', 'val', 'trainval', 'trainaug') + idx_split: + index of split + scan_per_batch: + number of scans to load into memory as the dataset is large + use that together with reload_buffer + act_labels: + actual labels involved in training process. Should be a subset of all labels + npart: number of chunks for splitting a 3d volume + nsup: number of support scans, equivalent to nshot + """ + mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode, scan_per_load = scan_per_load, transforms=transforms, min_fg = 1, fix_length = fix_length, nsup = nsup, **kwargs) + mydataset.add_attrib('basic', attrib_basic, {}) + + valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart) + + return valset, mydataset \ No newline at end of file diff --git a/dataloaders/image_transforms.py b/dataloaders/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..80a2a417c080953ebb9fd2abb67ef91806a12c9e --- /dev/null +++ b/dataloaders/image_transforms.py @@ -0,0 +1,362 @@ +""" +Image transforms functions for data augmentation +Credit to Dr. Jo Schlemper +""" + +from collections.abc import Sequence +import cv2 +import numpy as np +import scipy +from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage.interpolation import map_coordinates +from numpy.lib.stride_tricks import as_strided +import numpy as np +import cv2 +from scipy.ndimage import map_coordinates +from numpy.lib.stride_tricks import as_strided +from multiprocessing import Pool +import albumentations as A +import time + +###### UTILITIES ###### +def random_num_generator(config, random_state=np.random): + if config[0] == 'uniform': + ret = random_state.uniform(config[1], config[2], 1)[0] + elif config[0] == 'lognormal': + ret = random_state.lognormal(config[1], config[2], 1)[0] + else: + #print(config) + raise Exception('unsupported format') + return ret + +def get_translation_matrix(translation): + """ translation: [tx, ty] """ + tx, ty = translation + translation_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + return translation_matrix + + + +def get_rotation_matrix(rotation, input_shape, centred=True): + theta = np.pi / 180 * np.array(rotation) + if centred: + rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1) + rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) + else: + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1]]) + return rotation_matrix + +def get_zoom_matrix(zoom, input_shape, centred=True): + zx, zy = zoom + if centred: + zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0]) + zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) + else: + zoom_matrix = np.array([[zx, 0, 0], + [0, zy, 0], + [0, 0, 1]]) + return zoom_matrix + +def get_shear_matrix(shear_angle): + theta = (np.pi * shear_angle) / 180 + shear_matrix = np.array([[1, -np.sin(theta), 0], + [0, np.cos(theta), 0], + [0, 0, 1]]) + return shear_matrix + +###### AFFINE TRANSFORM ###### +class RandomAffine(object): + """Apply random affine transformation on a numpy.ndarray (H x W x C) + Comment by co1818: this is still doing affine on 2d (H x W plane). + A same transform is applied to all C channels + + Parameter: + ---------- + + alpha: Range [0, 4] seems good for small images + + order: interpolation method (c.f. opencv) + """ + + def __init__(self, + rotation_range=None, + translation_range=None, + shear_range=None, + zoom_range=None, + zoom_keep_aspect=False, + interp='bilinear', + use_3d=False, + order=3): + """ + Perform an affine transforms. + + Arguments + --------- + rotation_range : one integer or float + image will be rotated randomly between (-degrees, degrees) + + translation_range : (x_shift, y_shift) + shifts in pixels + + *NOT TESTED* shear_range : float + image will be sheared randomly between (-degrees, degrees) + + zoom_range : (zoom_min, zoom_max) + list/tuple with two floats between [0, infinity). + first float should be less than the second + lower and upper bounds on percent zoom. + Anything less than 1.0 will zoom in on the image, + anything greater than 1.0 will zoom out on the image. + e.g. (0.7, 1.0) will only zoom in, + (1.0, 1.4) will only zoom out, + (0.7, 1.4) will randomly zoom in or out + """ + + self.rotation_range = rotation_range + self.translation_range = translation_range + self.shear_range = shear_range + self.zoom_range = zoom_range + self.zoom_keep_aspect = zoom_keep_aspect + self.interp = interp + self.order = order + self.use_3d = use_3d + + def build_M(self, input_shape): + tfx = [] + final_tfx = np.eye(3) + if self.rotation_range: + rot = np.random.uniform(-self.rotation_range, self.rotation_range) + tfx.append(get_rotation_matrix(rot, input_shape)) + if self.translation_range: + tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) + ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) + tfx.append(get_translation_matrix((tx,ty))) + if self.shear_range: + rot = np.random.uniform(-self.shear_range, self.shear_range) + tfx.append(get_shear_matrix(rot)) + if self.zoom_range: + sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + if self.zoom_keep_aspect: + sy = sx + else: + sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + + tfx.append(get_zoom_matrix((sx, sy), input_shape)) + + for tfx_mat in tfx: + final_tfx = np.dot(tfx_mat, final_tfx) + + return final_tfx.astype(np.float32) + + def __call__(self, image): + # build matrix + input_shape = image.shape[:2] + M = self.build_M(input_shape) + + res = np.zeros_like(image) + #if isinstance(self.interp, Sequence): + if type(self.order) is list or type(self.order) is tuple: + for i, intp in enumerate(self.order): + if self.use_3d: + res[..., i] = affine_transform_3d_via_M(image[..., i], M[:2], interp=intp) + else: + res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) + else: + # squeeze if needed + orig_shape = image.shape + image_s = np.squeeze(image) + if self.use_3d: + res = affine_transform_3d_via_M(image_s, M[:2], interp=self.order) + else: + res = affine_transform_via_M(image_s, M[:2], interp=self.order) + res = res.reshape(orig_shape) + + #res = affine_transform_via_M(image, M[:2], interp=self.order) + + return res + +def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): + imshape = image.shape + shape_size = imshape[:2] + + # Random affine + warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], + flags=interp, borderMode=borderMode) + + #print(imshape, warped.shape) + + warped = warped[..., np.newaxis].reshape(imshape) + + return warped + +def affine_transform_3d_via_M(vol, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): + """ + vol should be of shape (nx, ny, n1, ..., nm) + """ + # go over slice slice + res = np.zeros_like(vol) + for i in range(vol.shape[2]): + res[:, :, i] = affine_transform_via_M(vol[:,:,i], M, borderMode=borderMode, interp=interp) + + return res + + +###### ELASTIC TRANSFORM ###### +def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): + """Elastic deformation of image as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + assert image.ndim == 3 + shape = image.shape[:2] + + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + + x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') + indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] + result = np.empty_like(image) + for i in range(image.shape[2]): + result[:, :, i] = map_coordinates( + image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) + return result + +def elastic_transform_nd_3d(image, **kwargs): + """ + image_w_mask should be of shape (nx, ny, nz, 3) + """ + image_w_mask = image + start_time = time.time() + elastic_transform = A.ElasticTransform(alpha=10, sigma=20, alpha_affine=15, interpolation=1, border_mode=4, always_apply=True, p=0.5) + # print(f"elastic transform initilization took {time.time() - start_time} seconds") + img = image_w_mask[..., 0] + label = image_w_mask[..., -1] + transformed = elastic_transform(image=img, mask=label) + t_img = transformed['image'][..., np.newaxis] + t_mask = transformed['mask'][..., np.newaxis] + t_mask_bg = 1 - t_mask + t_mask = np.concatenate([t_mask_bg, t_mask], axis=-1) + + comp = np.concatenate([t_img, t_mask], axis=-1) + return comp + +def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): + """Expects data to be (nx, ny, n1 ,..., nm) + params: + ------ + + alpha: + the scaling parameter. + E.g.: alpha=2 => distorts images up to 2x scaling + + sigma: + standard deviation of gaussian filter. + E.g. + low (sig~=1e-3) => no smoothing, pixelated. + high (1/5 * imsize) => smooth, more like affine. + very high (1/2*im_size) => translation + """ + + if random_state is None: + random_state = np.random.RandomState(None) + + shape = image.shape + imsize = shape[:2] + dim = shape[2:] + + # Random affine + blur_size = int(4*sigma) | 1 + dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + + # use as_strided to copy things over across n1...nn channels + dx = as_strided(dx.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) + + dy = as_strided(dy.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) + + coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) + indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], + [dy, dx] + [0] * len(dim))] + + if lazy: + return indices + res = map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) + return res + +class ElasticTransform(object): + """Apply elastic transformation on a numpy.ndarray (H x W x C) + """ + + def __init__(self, alpha, sigma, order=1): + self.alpha = alpha + self.sigma = sigma + self.order = order + + def __call__(self, image): + if isinstance(self.alpha, Sequence): + alpha = random_num_generator(self.alpha) + else: + alpha = self.alpha + if isinstance(self.sigma, Sequence): + sigma = random_num_generator(self.sigma) + else: + sigma = self.sigma + return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) + +class RandomFlip3D(object): + + def __init__(self, h=True, v=True, t=True, p=0.5): + """ + Randomly flip an image horizontally and/or vertically with + some probability. + + Arguments + --------- + h : boolean + whether to horizontally flip w/ probability p + + v : boolean + whether to vertically flip w/ probability p + + p : float between [0,1] + probability with which to apply allowed flipping operations + """ + self.horizontal = h + self.vertical = v + self.depth = t + self.p = p + + def __call__(self, x, y=None): + # horizontal flip with p = self.p + if self.horizontal: + if np.random.random() < self.p: + x = x[::-1, ...] + + # vertical flip with p = self.p + if self.vertical: + if np.random.random() < self.p: + x = x[:, ::-1, ...] + + if self.depth: + if np.random.random() < self.p: + x = x[..., ::-1] + + return x + + diff --git a/dataloaders/niftiio.py b/dataloaders/niftiio.py new file mode 100644 index 0000000000000000000000000000000000000000..b5764dd3b073d039313553e85fb1d1c86905e913 --- /dev/null +++ b/dataloaders/niftiio.py @@ -0,0 +1,48 @@ +""" +Utils for datasets +""" +import numpy as np + +import numpy as np +import SimpleITK as sitk + + +def read_nii_bysitk(input_fid, peel_info = False): + """ read nii to numpy through simpleitk + peelinfo: taking direction, origin, spacing and metadata out + """ + img_obj = sitk.ReadImage(input_fid) + img_np = sitk.GetArrayFromImage(img_obj) + if peel_info: + info_obj = { + "spacing": img_obj.GetSpacing(), + "origin": img_obj.GetOrigin(), + "direction": img_obj.GetDirection(), + "array_size": img_np.shape + } + return img_np, info_obj + else: + return img_np + +def convert_to_sitk(input_mat, peeled_info): + """ + write a numpy array to sitk image object with essential meta-data + """ + nii_obj = sitk.GetImageFromArray(input_mat) + if peeled_info: + nii_obj.SetSpacing( peeled_info["spacing"] ) + nii_obj.SetOrigin( peeled_info["origin"] ) + nii_obj.SetDirection(peeled_info["direction"] ) + return nii_obj + +def np2itk(img, ref_obj): + """ + img: numpy array + ref_obj: reference sitk object for copying information from + """ + itk_obj = sitk.GetImageFromArray(img) + itk_obj.SetSpacing( ref_obj.GetSpacing() ) + itk_obj.SetOrigin( ref_obj.GetOrigin() ) + itk_obj.SetDirection( ref_obj.GetDirection() ) + return itk_obj + diff --git a/models/ProtoMedSAM.py b/models/ProtoMedSAM.py new file mode 100644 index 0000000000000000000000000000000000000000..26ce75d585aaa2d563324531213380b6cd5687ff --- /dev/null +++ b/models/ProtoMedSAM.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from models.ProtoSAM import ModelWrapper +from segment_anything import sam_model_registry +from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components + +class ProtoMedSAM(nn.Module): + def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/medsam_vit_b.pth", debug=False, use_cca=False, coarse_pred_only=False): + super().__init__() + if isinstance(image_size, int): + image_size = (image_size, image_size) + self.image_size = image_size + self.coarse_segmentation_model = coarse_segmentation_model + self.get_sam(sam_pretrained_path) + self.coarse_pred_only = coarse_pred_only + self.debug = debug + self.use_cca = use_cca + + + def get_sam(self, checkpoint_path): + model_type="vit_b" # TODO make generic? + if 'vit_h' in checkpoint_path: + model_type = "vit_h" + self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval() + + + torch.no_grad() + def medsam_inference(self, img_embed, box_1024, H, W, query_label=None): + box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device) + if len(box_torch.shape) == 2: + box_torch = box_torch[:, None, :] # (B, 1, 4) + + sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder( + points=None, + boxes=box_torch, + masks=None, + ) + low_res_logits, conf = self.medsam.mask_decoder( + image_embeddings=img_embed, # (B, 256, 64, 64) + image_pe=self.medsam.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) + sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) + dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) + multimask_output=True if query_label is not None else False, + ) + + low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256) + + low_res_pred = F.interpolate( + low_res_pred, + size=(H, W), + mode="bilinear", + align_corners=False, + ) # (1, 1, gt.shape) + low_res_pred = low_res_pred.squeeze().cpu() # (256, 256) + + low_res_pred = low_res_pred.numpy() + medsam_seg = (low_res_pred > 0.5).astype(np.uint8) + + if query_label is not None: + medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :] + + return medsam_seg, conf.cpu().detach().numpy() + + def get_iou(self, pred, label): + """ + pred np array shape h,w type uint8 + label np array shpae h,w type uiint8 + """ + tp = np.logical_and(pred, label).sum() + fp = np.logical_and(pred, 1-label).sum() + fn = np.logical_and(1-pred, label).sum() + iou = tp / (tp + fp + fn) + return iou + + def get_best_mask(self, masks, labels): + """ + masks np shape ( B, h, w) + labels torch shape (1, H, W) + """ + np_labels = labels[0].clone().detach().cpu().numpy() + best_iou, best_mask = 0, None + for mask in masks: + iou = self.get_iou(mask, np_labels) + if iou > best_iou: + best_iou = iou + best_mask = mask + + return best_mask + + def get_bbox(self, pred): + """ + pred is tensor of shape (H,W) - 1 is fg, 0 is bg. + return bbox of pred s.t np.array([xmin, y_min, xmax, ymax]) + """ + if isinstance(pred, np.ndarray): + pred = torch.from_numpy(pred) + if pred.max() == 0: + return None + indices = torch.nonzero(pred) + ymin, xmin = indices.min(dim=0)[0] + ymax, xmax = indices.max(dim=0)[0] + return np.array([xmin, ymin, xmax, ymax]) + + + def get_bbox_per_cc(self, conn_components): + """ + conn_components: output of cca function + return list of bboxes per connected component, each bbox is a list of 2d points + """ + bboxes = [] + for i in range(1, conn_components[0]): + # get the indices of the foreground points + pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8) + bboxes.append(self.get_bbox(pred)) + + bboxes = np.array(bboxes) + return bboxes + + def forward(self, query_image, coarse_model_input, degrees_rotate=0): + """ + query_image: 3d tensor of shape (1, 3, H, W) + images should be normalized with mean and std but not to [0, 1]? + """ + original_size = query_image.shape[-2] + # rotate query_image by degrees_rotate + rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate) + # print(f"rotating query image took {time.time() - start_time} seconds") + coarse_model_input.set_query_images(rotated_img) + output_logits_rot = self.coarse_segmentation_model(coarse_model_input) + # print(f"ALPNet took {time.time() - start_time} seconds") + + if degrees_rotate != 0: + output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate) + # print(f"reversing rotated output_logits took {time.time() - start_time} seconds") + else: + output_logits = output_logits_rot + + # check if softmax is needed + # output_p = output_logits.softmax(dim=1) + output_p = output_logits + pred = output_logits.argmax(dim=1)[0] + if self.debug: + _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu()) + plt.subplot(132) + plt.imshow(query_image[0,0].detach().cpu()) + plt.imshow(_pred, alpha=0.5) + plt.subplot(131) + # plot heatmap of prob of being fg + plt.imshow(output_p[0, 1].detach().cpu()) + # plot rotated query image and rotated pred + output_p_rot = output_logits_rot.softmax(dim=1) + _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu()) + _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0] + plt.subplot(133) + plt.imshow(rotated_img[0, 0].detach().cpu()) + plt.imshow(_pred_rot, alpha=0.5) + plt.savefig('debug/coarse_pred.png') + plt.close() + + if self.coarse_pred_only: + output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits + pred = output_logits.argmax(dim=1)[0] + conf = get_confidence_from_logits(output_logits) + if self.use_cca: + _pred = np.array(pred.detach().cpu()) + _pred, conf = cca(_pred, output_logits, return_conf=True) + pred = torch.from_numpy(_pred) + if self.training: + return output_logits, [conf] + return pred, [conf] + + if query_image.shape[-2:] != self.image_size: + query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear') + output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear') + if need_softmax(output_logits): + output_logits = output_logits.softmax(dim=1) + + output_p = output_logits + pred = output_p.argmax(dim=1)[0] + + _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu()) + if self.use_cca: + conn_components = cca(_pred, output_logits, return_cc=True) + conf=None + else: + conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True) + if self.debug: + plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf) + # print(f"connected components took {time.time() - start_time} seconds") + + if _pred.max() == 0: + if output_p.shape[-2:] != original_size: + output_p = F.interpolate(output_p, size=original_size, mode='bilinear') + return output_p.argmax(dim=1)[0], [0] + + H, W = query_image.shape[-2:] + # bbox = self.get_bbox(_pred) + bbox = self.get_bbox_per_cc(conn_components) + bbox = bbox / np.array([W, H, W, H]) * max(self.image_size) + query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min()) + with torch.no_grad(): + image_embedding = self.medsam.image_encoder(query_image) + + medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W) + + if self.debug: + fig, ax = plt.subplots(1, 2) + ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu()) + show_mask(medsam_seg, ax[0]) + ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu()) + show_box(bbox[0], ax[1]) + plt.savefig('debug/medsam_pred.png') + plt.close() + + medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device) + if medsam_seg.shape[-2:] != original_size: + medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0] + + return medsam_seg, [conf] + + def segment_all(self, query_image, query_label): + H, W = query_image.shape[-2:] + # bbox = self.get_bbox(_pred) + # bbox = self.get_bbox_per_cc(conn_components) + # bbox = bbox / np.array([W, H, W, H]) * max(self.image_size) + bbox = np.array([[0, 0, W, H]]) + query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min()) + with torch.no_grad(): + image_embedding = self.medsam.image_encoder(query_image) + + medsam_seg, conf= self.medsam_inference(image_embedding, bbox, H, W, query_label) + + if self.debug: + fig, ax = plt.subplots(1, 2) + ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu()) + show_mask(medsam_seg, ax[0]) + ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu()) + show_box(bbox[0], ax[1]) + plt.savefig('debug/medsam_pred.png') + plt.close() + + medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device) + if medsam_seg.shape[-2:] != (H, W): + medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')[0][0] + + return medsam_seg.view(H,W), [conf] + + +def show_mask(mask, ax, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch( + plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2) + ) diff --git a/models/ProtoSAM.py b/models/ProtoSAM.py new file mode 100644 index 0000000000000000000000000000000000000000..6617ab96774857c486f0ace352eb963931ef78e7 --- /dev/null +++ b/models/ProtoSAM.py @@ -0,0 +1,708 @@ +import warnings +import torch +import torch.nn as nn +from torch.nn import functional as F +import matplotlib.pyplot as plt +import numpy as np +from models.grid_proto_fewshot import FewShotSeg +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor +from models.SamWrapper import SamWrapper +from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits +from util.lora import inject_trainable_lora +from models.segment_anything.utils.transforms import ResizeLongestSide +import cv2 +import time +from abc import ABC, abstractmethod + +CONF_MODE="conf" +CENTROID_MODE="centroid" +BOTH_MODE="both" +POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE) + +TYPE_ALPNET="alpnet" +TYPE_SAM="sam" + +def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"): + num_labels, labels, stats, centroids = cca_output + # Create an output image with random colors for each component + output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8) + for label in range(1, num_labels): # Start from 1 to skip the background + mask = labels == label + output_image[mask] = np.random.randint(0, 255, size=3) + + # Plotting the original and the colored components image + plt.figure(figsize=(10, 5)) + plt.subplot(121), plt.imshow(original_image), plt.title('Original Image') + plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components') + if confidences is not None: + # Plot the axes color chart with the confidences, use the same colors as the connected components + plt.subplot(122) + scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet') + plt.colorbar(scatter) + + plt.savefig(title) + plt.close() + +class SegmentationInput(ABC): + @abstractmethod + def set_query_images(self, query_images): + pass + + def to(self, device): + pass + +class SegmentationOutput(ABC): + @abstractmethod + def get_prediction(self): + pass + +class ALPNetInput(SegmentationInput): # for alpnet + def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None): + self.supp_imgs = [support_images] + self.fore_mask = [support_labels] + self.back_mask = [[1 - sup_labels for sup_labels in support_labels]] + self.qry_imgs = [query_images] + self.isval = isval + self.val_wsize = val_wsize + self.show_viz = show_viz + self.supp_fts = supp_fts + + def set_query_images(self, query_images): + self.qry_imgs = [query_images] + + def to(self, device): + self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]] + self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]] + self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]] + self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs] + if self.supp_fts is not None: + self.supp_fts = self.supp_fts.to(device) + +class ALPNetOutput(SegmentationOutput): + def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts): + self.pred = pred + self.align_loss = align_loss + self.sim_maps = sim_maps + self.assign_maps = assign_maps + self.proto_grid = proto_grid + self.supp_fts = supp_fts + self.qry_fts = qry_fts + + def get_prediction(self): + return self.pred + +class SAMWrapperInput(SegmentationInput): + def __init__(self, image, image_labels): + self.image = image + self.image_labels = image_labels + + def set_query_images(self, query_images): + B, C, H, W = query_images.shape + if isinstance(query_images, torch.Tensor): + query_images = query_images.cpu().detach().numpy() + assert B == 1, "batch size must be 1" + query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255 + query_images = query_images.astype(np.uint8) + self.image = np.transpose(query_images[0], (1, 2, 0)) + + def to(self, device): + pass + + +class InputFactory(ABC): + @staticmethod + def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None): + + if input_type == TYPE_ALPNET: + return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts) + elif input_type == TYPE_SAM: + qimg = np.array(query_image.detach().cpu()) + B,C,H,W = qimg.shape + assert B == 1, "batch size must be 1" + gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W) + assert np.unique(gts).shape[0] <= 2, "support labels must be binary" + gts[gts > 0] = 1 + qimg = qimg.reshape(H,W,C) + qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255 + qimg = qimg.astype(np.uint8) + return SAMWrapperInput(qimg, gts) + else: + raise ValueError(f"input_type not supported") + + +class ModelWrapper(ABC): + def __init__(self, model): + self.model = model + + def __call__(self, input_data: SegmentationInput)->SegmentationOutput: + pass + + def state_dict(self): + return self.model.state_dict() + + def load_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) + + def eval(self): + self.model.eval() + + def train(self): + self.model.train() + + def parameters(self): + pass + +class ALPNetWrapper(ModelWrapper): + def __init__(self, model: FewShotSeg): + super().__init__(model) + + def __call__(self, input_data: ALPNetInput): + output = self.model(**input_data.__dict__) + output = ALPNetOutput(*output) + return output.pred + + def parameters(self): + return self.model.encoder.parameters() + + def train(self): + self.model.encoder.train() + +class SamWrapperWrapper(ModelWrapper): + def __init__(self, model:SamWrapper): + super().__init__(model) + + def __call__(self, input_data: SAMWrapperInput): + pred = self.model(**input_data.__dict__) + # make pred look like logits + pred = torch.tensor(pred).float()[None, None, ...] + pred = torch.cat([1-pred, pred], dim=1) + return pred + + def to(self, device): + self.model.sam.to(device) + +class ProtoSAM(nn.Module): + def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ): + super().__init__() + if isinstance(image_size, int): + image_size = (image_size, image_size) + self.image_size = image_size + self.coarse_segmentation_model = coarse_segmentation_model + self.get_sam(sam_pretrained_path, use_sam_trans) + self.num_points_for_sam = num_points_for_sam + self.use_points = use_points + self.use_bbox = use_bbox # if False then uses points + self.use_mask = use_mask + self.use_neg_points = use_neg_points + assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask" + self.use_cca = use_cca + self.point_mode = point_mode + if self.point_mode not in POINT_MODES: + raise ValueError(f"point mode must be one of {POINT_MODES}") + self.debug=debug + self.coarse_pred_only = coarse_pred_only + + def get_sam(self, checkpoint_path, use_sam_trans): + model_type="vit_b" # TODO make generic? + if 'vit_h' in checkpoint_path: + model_type = "vit_h" + self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval() + self.predictor = SamPredictor(self.sam) + self.sam.requires_grad_(False) + if use_sam_trans: + # sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1]) + sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size) + sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1) + sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1) + else: + sam_trans = None + + self.sam_trans = sam_trans + + def get_bbox(self, pred): + ''' + pred tensor of shape (H, W) where 1 represents foreground and 0 represents background + returns a list of 2d points representing the bbox + ''' + if isinstance(pred, np.ndarray): + pred = torch.tensor(pred) + # get the indices of the foreground points + indices = torch.nonzero(pred) + # get the min and max of the indices + min_x = indices[:, 1].min() + max_x = indices[:, 1].max() + min_y = indices[:, 0].min() + max_y = indices[:, 0].max() + # get the bbox + bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] + + + return bbox + + def get_bbox_per_cc(self, conn_components): + """ + conn_components: output of cca function + return list of bboxes per connected component, each bbox is a list of 2d points + """ + bboxes = [] + for i in range(1, conn_components[0]): + # get the indices of the foreground points + indices = torch.nonzero(torch.tensor(conn_components[1] == i)) + # get the min and max of the indices + min_x = indices[:, 1].min() + max_x = indices[:, 1].max() + min_y = indices[:, 0].min() + max_y = indices[:, 0].max() + # get the bbox + # bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] + # bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]] + # bbox should be in a XYXY format + bbox = [min_x, min_y, max_x, max_y] + bboxes.append(bbox) + + bboxes = np.array(bboxes) + return bboxes + + def get_most_conf_points(self, output_p_fg, pred, k): + ''' + get the k most confident points from pred + output_p: 3d tensor of shape (H, W) + pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background + ''' + # Create a mask where pred is 1 + mask = pred.bool() + + # Apply the mask to output_p_fg + masked_output_p_fg = output_p_fg[mask] + if masked_output_p_fg.numel() == 0: + return None, None + # Get the top k probabilities and their indices + confidences, indices = torch.topk(masked_output_p_fg, k) + + # Get the locations of the top k points in xy format + locations = torch.nonzero(mask)[indices] + # convert locations to xy format + locations = locations[:, [1, 0]] + # convert locations to list of lists + # points = [loc.tolist() for loc in locations] + + return locations.numpy(), [float(conf.item()) for conf in confidences] + + + def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None): + ''' + points: np array of shape (N, 2) where each row is a point in xy format + pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background + image: 2d tensor of shape (H,W) representing the image + bbox: list or np array of shape (N, 4) where each row is a bbox in xyxy format + ''' + warnings.filterwarnings('ignore', category=UserWarning) + if isinstance(pred, torch.Tensor): + pred = pred.cpu().detach().numpy() + if len(image.shape) == 3 and image.shape[0] == 3: + image = image.permute(1, 2, 0) + if title is None: + title="debug/most_conf_points.png" + + fig = plt.figure() + image = (image - image.min()) / (image.max() - image.min()) + plt.imshow(image) + plt.imshow(pred, alpha=0.5) + for i, point in enumerate(points): + plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red') + if confidences is not None: + plt.text(point[0], point[1], f"{confidences[i]:.3f}", fontsize=12, color='red') + # assume points is a list of lists + if bboxes is not None: + for bbox in bboxes: + if bbox is None: + continue + bbox = np.array(bbox) + # plt.scatter(bbox[:, 1], bbox[:, 0], c='red') + # plot a line connecting the points + box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]) + box = np.vstack([box, box[0]]) + plt.plot(box[:, 0], box[:, 1], c='green') + plt.colorbar() + fig.savefig(title) + plt.close(fig) + + def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None): + if len(image.shape) == 3: + image = image.permute(1, 2, 0) + image = (image - image.min()) / (image.max() - image.min()) + for i, (mask, score) in enumerate(zip(masks, scores)): + plt.figure(figsize=(10,10)) + plt.imshow(image) + show_mask(mask, plt.gca()) + if input_point is not None: + show_points(input_point, input_label, plt.gca()) + if input_box is not None: + show_box(input_box, plt.gca()) + plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) + # plt.axis('off') + plt.savefig(f'debug/sam_mask_{i+1}.png') + plt.close() + if i > 5: + break + + def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1): + """ + args: + conn_components: output of cca function + output_p: 3d tensor of shape (1, 2, H, W) + get_neg_points: bool, if True then return the negative points + l: int, number of negative points to get + """ + sam_input_points = [] + sam_neg_points = [] + fg_p = output_p[0, 1].detach().cpu() + + if get_neg_points: + # get global negative points + bg_p = output_p[0, 0].detach().cpu() + bg_p[bg_p < 0.95] = 0 + bg_pred = torch.where(bg_p > 0, 1, 0) + glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1) + if self.debug: + # plot the bg_p as a heatmap + plt.figure() + plt.imshow(bg_p) + plt.colorbar() + plt.savefig('debug/bg_p_heatmap.png') + plt.close() + + for i, cc_id in enumerate(np.unique(conn_components[1])): + # get self.num_points_for_sam most confident points from pred + if cc_id == 0: + continue # skip background + pred = torch.tensor(conn_components[1] == cc_id).float() + + if self.point_mode == CONF_MODE: + points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2) + elif self.point_mode == CENTROID_MODE: + points = conn_components[3][cc_id][None, :] # (1, 2) + confidences = [1 for _ in range(len(points))] + elif self.point_mode == BOTH_MODE: + points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) + point = conn_components[3][cc_id][None, :] + points = np.vstack([points, point]) # (N+1, 2) + confidences.append(1) + else: + raise NotImplementedError(f"point mode {self.point_mode} not implemented") + sam_input_points.append(np.array(points)) + + if get_neg_points: + pred_uint8 = (pred.numpy() * 255).astype(np.uint8) + + # Dilate the mask to expand it + kernel_size = 3 # Size of the dilation kernel, adjust accordingly + kernel = np.ones((kernel_size, kernel_size), np.uint8) + dilation_iterations = 10 # Number of times dilation is applied, adjust as needed + dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations) + + # Subtract the original mask from the dilated mask + # This will give a boundary that is only outside the original mask + outside_boundary = dilated_mask - pred_uint8 + + # Convert back to torch tensor and normalize + boundary = torch.tensor(outside_boundary).float() / 255 + try: + bg_p = output_p[0, 0].detach().cpu() + neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l) + except RuntimeError as e: + # make each point (None, None) + neg_points = None + # append global negative points to the negative points + if neg_points is not None and glob_neg_points is not None: + neg_points = np.vstack([neg_points, glob_neg_points]) + else: + neg_points = glob_neg_points if neg_points is None else neg_points + if self.debug and neg_points is not None: + # draw an image with 2 subplots, one is the pred and the other is the boundary + plt.figure() + plt.subplot(121) + plt.imshow(pred) + plt.imshow(boundary, alpha=0.5) + # plot the neg points + plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') + plt.subplot(122) + plt.imshow(pred) + plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') + plt.savefig('debug/pred_and_boundary.png') + plt.close() + sam_neg_points.append(neg_points) + else: + # create a list of None same shape as points + sam_neg_points = [None for _ in range(len(sam_input_points))] + + sam_input_labels = np.array([l+1 for l, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))]) + sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2) + # if get_neg_points: + sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None + if sam_neg_input_points is not None: + sam_neg_input_points = sam_neg_points + sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) ) + else: + sam_neg_input_points = None + sam_neg_input_labels = None + + return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels + + def get_sam_input_mask(self, conn_components): + sam_input_masks = [] + sam_input_mask_lables = [] + for i, cc_id in enumerate(np.unique(conn_components[1])): + # get self.num_points_for_sam most confident points from pred + if cc_id == 0: + continue + pred = torch.tensor(conn_components[1] == cc_id).float() + sam_input_masks.append(pred) + sam_input_mask_lables.append(cc_id) + + sam_input_masks = np.stack(sam_input_masks) + sam_input_mask_lables = np.array(sam_input_mask_lables) + + return sam_input_masks, sam_input_mask_lables + + def predict_w_masks(self, sam_input_masks, qry_img, original_size): + masks = [] + scores = [] + for in_mask in sam_input_masks: + in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST) + in_mask[in_mask == 1] = 10 + in_mask[in_mask == 0] = -8 + assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 + self.predictor.set_image(qry_img) + mask, score, _ = self.predictor.predict( + mask_input=in_mask[None, ...].astype(np.uint8), + multimask_output=True) + # get max index from score + if self.debug: + # plot each channel of mask + fig, ax = plt.subplots(1, 4, figsize=(15, 5)) + for i in range(mask.shape[0]): + ax[i].imshow(qry_img) + ax[i].imshow(mask[i], alpha=0.5) + ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18) + # ax[i].axis('off') + ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST)) + fig.savefig(f'debug/sam_mask_from_mask_prompts.png') + plt.close(fig) + + + max_index = score.argmax() + masks.append(mask[max_index]) + scores.append(score[max_index]) + + return masks, scores + + def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False): + masks, scores = [], [] + self.predictor.set_image(qry_img) + # if sam_input_points is None: + # sam_input_points = [None for _ in range(len(bboxes))] + for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points): + assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 + points = point + point_labels = np.array([1] * len(point)) if point is not None else None + if self.use_neg_points: + neg_points = [npoint for npoint in neg_point if None not in npoint] + points = np.vstack([point, *neg_points]) + point_labels = np.array([1] * len(point) + [0] * len(neg_points)) + if self.debug: + self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points + mask, score, _ = self.predictor.predict( + point_coords=points, + point_labels=point_labels, + # box=bbox_xyxy[None, :] if bbox_xyxy is not None else None, + box = bbox_xyxy if bbox_xyxy is not None else None, + # mask_input=sam_mask_input, + return_logits=return_logits, + multimask_output=False if self.use_cca else True + ) + # best_pred_idx = np.argmax(score) + best_pred_idx = 0 + masks.append(mask[best_pred_idx]) + scores.append(score[best_pred_idx]) + + if self.debug: + # pass + self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None) + + return masks, scores + + + def forward(self, query_image, coarse_model_input, degrees_rotate=0): + """ + query_image: 3d tensor of shape (1, 3, H, W) + images should be normalized with mean and std but not to [0, 1]? + """ + original_size = query_image.shape[-2] + # rotate query_image by degrees_rotate + start_time = time.time() + rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate) + # print(f"rotating query image took {time.time() - start_time} seconds") + start_time = time.time() + coarse_model_input.set_query_images(rotated_img) + output_logits_rot = self.coarse_segmentation_model(coarse_model_input) + # print(f"ALPNet took {time.time() - start_time} seconds") + + if degrees_rotate != 0: + start_time = time.time() + output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate) + # print(f"reversing rotated output_logits took {time.time() - start_time} seconds") + else: + output_logits = output_logits_rot + + # check if softmax is needed + output_p = output_logits.softmax(dim=1) + # output_p = output_logits + pred = output_logits.argmax(dim=1)[0] + if self.debug: + _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu()) + plt.subplot(132) + plt.imshow(query_image[0,0].detach().cpu()) + plt.imshow(_pred, alpha=0.5) + plt.subplot(131) + # plot heatmap of prob of being fg + plt.imshow(output_p[0, 1].detach().cpu()) + # plot rotated query image and rotated pred + output_p_rot = output_logits_rot.softmax(dim=1) + _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu()) + _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0] + plt.subplot(133) + plt.imshow(rotated_img[0, 0].detach().cpu()) + plt.imshow(_pred_rot, alpha=0.5) + plt.savefig('debug/coarse_pred.png') + plt.close() + + if self.coarse_pred_only: + output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits + pred = output_logits.argmax(dim=1)[0] + conf = get_confidence_from_logits(output_logits) + if self.use_cca: + _pred = np.array(pred.detach().cpu()) + _pred, conf = cca(_pred, output_logits, return_conf=True) + pred = torch.from_numpy(_pred) + if self.training: + return output_logits, [conf] + # Ensure pred is a float tensor for consistent visualization + return pred.float(), [conf] + + if query_image.shape[-2:] != self.image_size: + query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear') + output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear') + # if need_softmax(output_logits): + # output_logits = output_logits.softmax(dim=1) + + # output_p = output_logits + output_p = output_logits.softmax(dim=1) + pred = output_p.argmax(dim=1)[0] + + _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu()) + start_time = time.time() + if self.use_cca: + conn_components = cca(_pred, output_logits, return_cc=True) + conf=None + else: + conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True) + if self.debug: + plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf) + # print(f"connected components took {time.time() - start_time} seconds") + if _pred.max() == 0: + return output_p.argmax(dim=1)[0], [0] + + # get bbox from pred + if self.use_bbox: + start_time = time.time() + try: + bboxes = self.get_bbox_per_cc(conn_components) + except: + bboxes = [None] * conn_components[0] + else: + bboxes = [None] * conn_components[0] + # print(f"getting bboxes took {time.time() - start_time} seconds") + + + start_time = time.time() + if self.use_points: + sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1) + else: + sam_input_points = [None] * conn_components[0] + sam_input_point_labels = [None] * conn_components[0] + sam_neg_input_points = [None] * conn_components[0] + sam_neg_input_labels = [None] * conn_components[0] + # print(f"getting sam input points took {time.time() - start_time} seconds") + + if self.use_mask: + sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components) + else: + sam_input_masks = None + sam_input_mask_labels = None + + if self.debug and sam_input_points is not None: + title = f'debug/most_conf_points.png' + if self.use_cca: + title = f'debug/most_conf_points_cca.png' + # convert points to a list where each item is a list of 2 elements in xy format + self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points + + # self.sam_trans = None + if self.sam_trans is None: + query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() + else: + query_image = self.sam_trans.apply_image_torch(query_image[0]) + query_image = self.sam_trans.preprocess(query_image) + query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() + # mask = self.sam_trans.preprocess(mask) + + + query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8) + if self.use_mask: + masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size) + + start_time = time.time() + if self.use_points or self.use_bbox: + masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False) + # print(f"predicting w points/bbox took {time.time() - start_time} seconds") + + pred = sum(masks) + if not self.training: + pred = pred > 0 + pred = torch.tensor(pred).float().to(output_p.device) + + # pred = torch.tensor(masks[0]).float().cuda() + # resize pred to the size of the input + pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0] + + return pred, scores + + +def show_mask(mask, ax, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) + +def need_softmax(tensor, dim=1): + return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0)) + + + + + \ No newline at end of file diff --git a/models/SamWrapper.py b/models/SamWrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e200c9d72b3df8b487f7137f3c5ac8f8e2239651 --- /dev/null +++ b/models/SamWrapper.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import numpy as np +from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator +from models.segment_anything.utils.transforms import ResizeLongestSide +import cv2 + +def get_iou(mask, label): + tp = (mask * label).sum() + fp = (mask * (1-label)).sum() + fn = ((1-mask) * label).sum() + iou = tp / (tp + fp + fn) + return iou + +class SamWrapper(nn.Module): + def __init__(self,sam_args): + """ + sam_args: dict should include the following + { + "model_type": "vit_h", + "sam_checkpoint": "path to checkpoint" pretrained_model/sam_vit_h.pth + } + """ + super().__init__() + self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint']) + self.mask_generator = SamAutomaticMaskGenerator(self.sam) + self.transform = ResizeLongestSide(self.sam.image_encoder.img_size) + + def forward(self, image, image_labels): + """ + generate masks for a batch of images + return mask that has the largest iou with the image label + Args: + images (np.ndarray): The image to generate masks for, in HWC uint8 format. + image_labels (np.ndarray): The image labels to generate masks for, in HWC uint8 format. assuming binary labels + """ + image = self.transform.apply_image(image) + masks = self.mask_generator.generate(image) + + best_index, best_iou = None, 0 + for i, mask in enumerate(masks): + segmentation = mask['segmentation'] + iou = get_iou(segmentation.astype(np.uint8), image_labels) + if best_index is None or iou > best_iou: + best_index = i + best_iou = iou + + return masks[best_index]['segmentation'] + + def to(self, device): + self.sam.to(device) + self.mask_generator.to(device) + self.mask_generator.predictor.to(device) + + + +if __name__ == "__main__": + sam_args = { + "model_type": "vit_h", + "sam_checkpoint": "pretrained_model/sam_vit_h.pth" + } + sam_wrapper = SamWrapper(sam_args).cuda() + image = cv2.imread("./Kheops-Pyramid.jpg") + image = np.array(image).astype('uint8') + image_labels = torch.rand(1,3,224,224) + sam_wrapper(image, image_labels) + + \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/__pycache__/ProtoSAM.cpython-312.pyc b/models/__pycache__/ProtoSAM.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c016076ac26fe6c2629c15cf243768e0808b0a4e Binary files /dev/null and b/models/__pycache__/ProtoSAM.cpython-312.pyc differ diff --git a/models/__pycache__/SamWrapper.cpython-312.pyc b/models/__pycache__/SamWrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a0f90638a9b23418f2f77b12f9bc4fcc3229f9f Binary files /dev/null and b/models/__pycache__/SamWrapper.cpython-312.pyc differ diff --git a/models/__pycache__/__init__.cpython-312.pyc b/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..477d2c8b5e4cd27e2f7a6b0f4197712c155126db Binary files /dev/null and b/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/__pycache__/alpmodule.cpython-312.pyc b/models/__pycache__/alpmodule.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8deaf542b6c4429a6fe9545355868213ed10bbba Binary files /dev/null and b/models/__pycache__/alpmodule.cpython-312.pyc differ diff --git a/models/__pycache__/grid_proto_fewshot.cpython-312.pyc b/models/__pycache__/grid_proto_fewshot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fbd9a441fb3a8c97ade4cfdd32433b0b31d6b64 Binary files /dev/null and b/models/__pycache__/grid_proto_fewshot.cpython-312.pyc differ diff --git a/models/alpmodule.py b/models/alpmodule.py new file mode 100644 index 0000000000000000000000000000000000000000..5572bd4eebe40f663e6aee281b2583690cd442d5 --- /dev/null +++ b/models/alpmodule.py @@ -0,0 +1,199 @@ +""" +ALPModule +""" +import torch +import time +import math +from torch import nn +from torch.nn import functional as F +import numpy as np +from pdb import set_trace +import matplotlib.pyplot as plt +# for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm + +def safe_norm(x, p = 2, dim = 1, eps = 1e-4): + x_norm = torch.norm(x, p = p, dim = dim) # .detach() + x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps) + x = x.div(x_norm.unsqueeze(1).expand_as(x)) + return x + + +class MultiProtoAsConv(nn.Module): + def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode = 'bilinear'): + """ + ALPModule + Args: + proto_grid: Grid size when doing multi-prototyping. For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2 + feature_hw: Spatial size of input feature map + + """ + super(MultiProtoAsConv, self).__init__() + self.feature_hw = feature_hw + self.proto_grid = proto_grid + self.upsample_mode = upsample_mode + kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ] + self.kernel_size = kernel_size + print(f"MultiProtoAsConv: kernel_size: {kernel_size}") + self.avg_pool_op = nn.AvgPool2d( kernel_size ) + + if use_attention: + self.proto_fg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True) + self.proto_bg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True) + self.fg_mask_projection = nn.Sequential( + nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True), + ) + self.bg_mask_projection = nn.Sequential( + nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True), + ) + + def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False ): + if mode == 'mask': + pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w] + # incase there are more than one prototypes in the same location, take the max + pred_mask = pred_mask.max(dim = 0)[0].unsqueeze(0) + vis_dict = {'proto_assign': pred_mask} # things to visualize + if vis_sim: + vis_dict['raw_local_sims'] = pred_mask + return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w] + + elif mode == 'gridconv': + dists = F.conv2d(query, prototypes[..., None, None]) * 20 + + pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) + debug_assign = dists.argmax(dim = 1).float().detach() + + vis_dict = {'proto_assign': debug_assign} # things to visualize + + if vis_sim: # return the similarity for visualization + vis_dict['raw_local_sims'] = dists.clone().detach() + return pred_grid, [debug_assign], vis_dict + + elif mode == 'gridconv+': + dists = F.conv2d(query, prototypes[..., None, None]) * 20 + + pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) + # raw_local_sims = dists.det ach() + + debug_assign = dists.argmax(dim = 1).float() + + vis_dict = {'proto_assign': debug_assign} + if vis_sim: + vis_dict['raw_local_sims'] = dists.clone().detach() + + return pred_grid, [debug_assign], vis_dict + + else: + raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.") + + + def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval = False): + if mode == 'mask': + proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \ + / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C + + pro_n = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything + pro_n = proto + proto_grid = sup_y.clone().detach() # a single prototype for the whole image + resized_proto_grid = proto_grid + non_zero = torch.nonzero(proto_grid) + + elif mode == 'gridconv': + nch = sup_x.shape[1] + + sup_nshot = sup_x.shape[0] + # if len(sup_x.shape) > 4: + # sup_x = sup_x.squeeze() + n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) + n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc + n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) + + sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) + + # get a grid of prototypes + proto_grid = sup_y_g.clone().detach() + proto_grid[proto_grid < thresh] = 0 + # interpolate the grid to the original size + non_zero = torch.nonzero(proto_grid) + + resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize]) + for index in non_zero: + resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + 2] = proto_grid[0, 0, index[2], index[3]] + + sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) + protos = n_sup_x[sup_y_g > thresh, :] # npro, nc + pro_n = safe_norm(protos) + + elif mode == 'gridconv+': + nch = sup_x.shape[1] + n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) + sup_nshot = sup_x.shape[0] + n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) + n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) + sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) + + # get a grid of prototypes + proto_grid = sup_y_g.clone().detach() + proto_grid[proto_grid < thresh] = 0 + non_zero = torch.nonzero(proto_grid) + for i, idx in enumerate(non_zero): + proto_grid[0, idx[1], idx[2], idx[3]] = i + 1 + resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize]) + for index in non_zero: + resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + 2] = proto_grid[0, 0, index[2], index[3]] + + sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) + protos = n_sup_x[sup_y_g > thresh, :] + + glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \ + / (sup_y.sum(dim=(-1, -2)) + 1e-5) + + pro_n = safe_norm(torch.cat( [protos, glb_proto], dim = 0 )) + return pro_n, resized_proto_grid, non_zero + + def forward(self, qry, sup_x, sup_y, mode, thresh, isval = False, val_wsize = None, vis_sim = False, get_prototypes=False, **kwargs): + """ + Now supports + Args: + mode: 'mask'/ 'grid'. if mask, works as original prototyping + qry: [way(1), nc, h, w] + sup_x: [nb, nc, h, w] + sup_y: [nb, 1, h, w] + vis_sim: visualize raw similarities or not + New + mode: 'mask'/ 'grid'. if mask, works as original prototyping + qry: [way(1), nb(1), nc, h, w] + sup_x: [way(1), shot, nb(1), nc, h, w] + sup_y: [way(1), shot, nb(1), h, w] + vis_sim: visualize raw similarities or not + """ + + qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w] + sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w] + sup_y = sup_y.squeeze(0) # [nshot, 1, h, w] + + def safe_norm(x, p = 2, dim = 1, eps = 1e-4): + x_norm = torch.norm(x, p = p, dim = dim) # .detach() + x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps) + x = x.div(x_norm.unsqueeze(1).expand_as(x)) + return x + if val_wsize is None: + val_wsize = self.avg_pool_op.kernel_size + if isinstance(val_wsize, (tuple, list)): + val_wsize = val_wsize[0] + sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1]) + pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval) + if 0 in pro_n.shape: + print("failed to find prototypes") + qry_n = qry if mode == 'mask' else safe_norm(qry) + pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim) + + return pred_grid, debug_assign, vis_dict, proto_grid + diff --git a/models/backbone/__init__.py b/models/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/backbone/__pycache__/__init__.cpython-312.pyc b/models/backbone/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0034eb049aa1786bb4697929075ae384228280bf Binary files /dev/null and b/models/backbone/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc b/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62a61bb0fe4b143084f791d6e93d9c1faea3cc57 Binary files /dev/null and b/models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc differ diff --git a/models/backbone/torchvision_backbones.py b/models/backbone/torchvision_backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..deec7f5717816f4aa8f92a4be11fac367c92ad41 --- /dev/null +++ b/models/backbone/torchvision_backbones.py @@ -0,0 +1,58 @@ +""" +Backbones supported by torchvison. +""" +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchvision + +class TVDeeplabRes101Encoder(nn.Module): + """ + FCN-Resnet101 backbone from torchvision deeplabv3 + No ASPP is used as we found emperically it hurts performance + """ + def __init__(self, use_coco_init, aux_dim_keep = 64, use_aspp = False): + super().__init__() + _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None) + if use_coco_init: + print("###### NETWORK: Using ms-coco initialization ######") + else: + print("###### NETWORK: Training from scratch ######") + + _model_list = list(_model.children()) + self.aux_dim_keep = aux_dim_keep + self.backbone = _model_list[0] + self.localconv = nn.Conv2d(2048, 256,kernel_size = 1, stride = 1, bias = False) # reduce feature map dimension + self.asppconv = nn.Conv2d(256, 256,kernel_size = 1, bias = False) + + _aspp = _model_list[1][0] + _conv256 = _model_list[1][1] + self.aspp_out = nn.Sequential(*[_aspp, _conv256] ) + self.use_aspp = use_aspp + + def forward(self, x_in, low_level): + """ + Args: + low_level: whether returning aggregated low-level features in FCN + """ + fts = self.backbone(x_in) + if self.use_aspp: + fts256 = self.aspp_out(fts['out']) + high_level_fts = fts256 + else: + fts2048 = fts['out'] + high_level_fts = self.localconv(fts2048) + + if low_level: + low_level_fts = fts['aux'][:, : self.aux_dim_keep] + return high_level_fts, low_level_fts + else: + return high_level_fts + + + + + diff --git a/models/grid_proto_fewshot.py b/models/grid_proto_fewshot.py new file mode 100644 index 0000000000000000000000000000000000000000..0e62bde0de706941281ef06d9c2300711462c753 --- /dev/null +++ b/models/grid_proto_fewshot.py @@ -0,0 +1,427 @@ +""" +ALPNet +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from .alpmodule import MultiProtoAsConv +from .backbone.torchvision_backbones import TVDeeplabRes101Encoder +from util.consts import DEFAULT_FEATURE_SIZE +from util.lora import inject_trainable_lora +# from util.utils import load_config_from_url, plot_dinov2_fts +import math + +# Specify a local path to the repository (or use installed package instead) +FG_PROT_MODE = 'gridconv+' # using both local and global prototype +# FG_PROT_MODE = 'mask' +# using local prototype only. Also 'mask' refers to using global prototype only (as done in vanilla PANet) +BG_PROT_MODE = 'gridconv' + +# thresholds for deciding class of prototypes +FG_THRESH = 0.95 +BG_THRESH = 0.95 + + +class FewShotSeg(nn.Module): + """ + ALPNet + Args: + in_channels: Number of input channels + cfg: Model configurations + """ + + def __init__(self, image_size, pretrained_path=None, cfg=None): + super(FewShotSeg, self).__init__() + self.image_size = image_size + self.pretrained_path = pretrained_path + print(f'###### Pre-trained path: {self.pretrained_path} ######') + self.config = cfg or { + 'align': False, 'debug': False} + self.get_encoder() + self.get_cls() + if self.pretrained_path: + self.load_state_dict(torch.load(self.pretrained_path), strict=True) + print( + f'###### Pre-trained model f{self.pretrained_path} has been loaded ######') + + def get_encoder(self): + self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE, + DEFAULT_FEATURE_SIZE] # default feature map size + if self.config['which_model'] == 'dlfcn_res101' or self.config['which_model'] == 'default': + use_coco_init = self.config['use_coco_init'] + self.encoder = TVDeeplabRes101Encoder(use_coco_init) + self.config['feature_hw'] = [ + math.ceil(self.image_size/8), math.ceil(self.image_size/8)] + elif self.config['which_model'] == 'dinov2_l14': + self.encoder = torch.hub.load( + 'facebookresearch/dinov2', 'dinov2_vitl14') + self.config['feature_hw'] = [max( + self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)] + elif self.config['which_model'] == 'dinov2_l14_reg': + try: + self.encoder = torch.hub.load( + 'facebookresearch/dinov2', 'dinov2_vitl14_reg') + except RuntimeError as e: + self.encoder = torch.hub.load( + 'facebookresearch/dino', 'dinov2_vitl14_reg', force_reload=True) + self.config['feature_hw'] = [max( + self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)] + elif self.config['which_model'] == 'dinov2_b14': + self.encoder = torch.hub.load( + 'facebookresearch/dinov2', 'dinov2_vitb14') + self.config['feature_hw'] = [max( + self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)] + else: + raise NotImplementedError( + f'Backbone network {self.config["which_model"]} not implemented') + + if self.config['lora'] > 0: + self.encoder.requires_grad_(False) + print(f'Injecting LoRA with rank:{self.config["lora"]}') + encoder_lora_params = inject_trainable_lora( + self.encoder, r=self.config['lora']) + + def get_features(self, imgs_concat): + if self.config['which_model'] == 'dlfcn_res101': + img_fts = self.encoder(imgs_concat, low_level=False) + elif 'dino' in self.config['which_model']: + # resize imgs_concat to the closest size that is divisble by 14 + imgs_concat = F.interpolate(imgs_concat, size=( + self.image_size // 14 * 14, self.image_size // 14 * 14), mode='bilinear') + dino_fts = self.encoder.forward_features(imgs_concat) + img_fts = dino_fts["x_norm_patchtokens"] # B, HW, C + img_fts = img_fts.permute(0, 2, 1) # B, C, HW + C, HW = img_fts.shape[-2:] + img_fts = img_fts.view(-1, C, int(HW**0.5), + int(HW**0.5)) # B, C, H, W + if HW < DEFAULT_FEATURE_SIZE ** 2: + img_fts = F.interpolate(img_fts, size=( + DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE), mode='bilinear') # this is if h,w < (32,32) + else: + raise NotImplementedError( + f'Backbone network {self.config["which_model"]} not implemented') + + return img_fts + + def get_cls(self): + """ + Obtain the similarity-based classifier + """ + proto_hw = self.config["proto_grid_size"] + + if self.config['cls_name'] == 'grid_proto': + embed_dim = 256 + if 'dinov2_b14' in self.config['which_model']: + embed_dim = 768 + elif 'dinov2_l14' in self.config['which_model']: + embed_dim = 1024 + self.cls_unit = MultiProtoAsConv(proto_grid=[proto_hw, proto_hw], feature_hw=self.config["feature_hw"], embed_dim=embed_dim) # when treating it as ordinary prototype + print(f"cls unit feature hw: {self.cls_unit.feature_hw}") + else: + raise NotImplementedError( + f'Classifier {self.config["cls_name"]} not implemented') + + def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None): + predictions = [] + for res in resolutions: + supp_imgs_resized = [[F.interpolate(supp_img[0], size=( + res, res), mode='bilinear') for supp_img in supp_imgs]] if supp_imgs[0][0].shape[-1] != res else supp_imgs + fore_mask_resized = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[ + 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != res else fore_mask + back_mask_resized = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[ + 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != res else back_mask + qry_imgs_resized = [F.interpolate(qry_img, size=(res, res), mode='bilinear') + for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != res else qry_imgs + + pred = self.forward(supp_imgs_resized, fore_mask_resized, back_mask_resized, + qry_imgs_resized, isval, val_wsize, show_viz, supp_fts)[0] + predictions.append(pred) + + def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs): + supp_imgs = [[F.interpolate(supp_img, size=( + self.image_size, self.image_size), mode='bilinear') for supp_img in supp_imgs_way] for supp_imgs_way in supp_imgs] + fore_mask = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[ + 0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != self.image_size else fore_mask + back_mask = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[ + 0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != self.image_size else back_mask + qry_imgs = [F.interpolate(qry_img, size=(self.image_size, self.image_size), mode='bilinear') + for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != self.image_size else qry_imgs + return supp_imgs, fore_mask, back_mask, qry_imgs + + def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None): + """ + Args: + supp_imgs: support images + way x shot x [B x 3 x H x W], list of lists of tensors + fore_mask: foreground masks for support images + way x shot x [B x H x W], list of lists of tensors + back_mask: background masks for support images + way x shot x [B x H x W], list of lists of tensors + qry_imgs: query images + N x [B x 3 x H x W], list of tensors + show_viz: return the visualization dictionary + """ + # ('Please go through this piece of code carefully') + # supp_imgs, fore_mask, back_mask, qry_imgs = self.resize_inputs_to_image_size( + # supp_imgs, fore_mask, back_mask, qry_imgs) + + n_ways = len(supp_imgs) + n_shots = len(supp_imgs[0]) + n_queries = len(qry_imgs) + + # NOTE: actual shot in support goes in batch dimension + assert n_ways == 1, "Multi-shot has not been implemented yet" + assert n_queries == 1 + + sup_bsize = supp_imgs[0][0].shape[0] + img_size = supp_imgs[0][0].shape[-2:] + if self.config["cls_name"] == 'grid_proto_3d': + img_size = supp_imgs[0][0].shape[-3:] + qry_bsize = qry_imgs[0].shape[0] + + imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs] + + [torch.cat(qry_imgs, dim=0),], dim=0) + + img_fts = self.get_features(imgs_concat) + if len(img_fts.shape) == 5: # for 3D + fts_size = img_fts.shape[-3:] + else: + fts_size = img_fts.shape[-2:] + if supp_fts is None: + supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view( + n_ways, n_shots, sup_bsize, -1, *fts_size) # wa x sh x b x c x h' x w' + qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view( + n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W' + else: + # N x B x C x H' x W' + qry_fts = img_fts.view(n_queries, qry_bsize, -1, *fts_size) + + fore_mask = torch.stack([torch.stack(way, dim=0) + for way in fore_mask], dim=0) # Wa x Sh x B x H' x W' + fore_mask = torch.autograd.Variable(fore_mask, requires_grad=True) + back_mask = torch.stack([torch.stack(way, dim=0) + for way in back_mask], dim=0) # Wa x Sh x B x H' x W' + + ###### Compute loss ###### + align_loss = 0 + outputs = [] + visualizes = [] # the buffer for visualization + + for epi in range(1): # batch dimension, fixed to 1 + fg_masks = [] # keep the way part + + ''' + for way in range(n_ways): + # note: index of n_ways starts from 0 + mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this + mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') + fg_masks.append( mean_sup_msk ) + + mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W] + ''' + # re-interpolate support mask to the same size as support feature + if len(fts_size) == 3: # TODO make more generic + res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze( + 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw']) + res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze( + 0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw']) + else: + res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size=fts_size, mode='nearest') + for fore_mask_w in fore_mask], dim=0) # [nway, ns, nb, nh', nw'] + res_bg_msk = torch.stack([F.interpolate(back_mask_w, size=fts_size, mode='nearest') + for back_mask_w in back_mask], dim=0) # [nway, ns, nb, nh', nw'] + + scores = [] + assign_maps = [] + bg_sim_maps = [] + fg_sim_maps = [] + bg_mode = BG_PROT_MODE + + _raw_score, _, aux_attr, _ = self.cls_unit( + qry_fts, supp_fts, res_bg_msk, mode=bg_mode, thresh=BG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz) + scores.append(_raw_score) + assign_maps.append(aux_attr['proto_assign']) + + for way, _msks in enumerate(res_fg_msk): + raw_scores = [] + for i, _msk in enumerate(_msks): + _msk = _msk.unsqueeze(0) + supp_ft = supp_fts[:, i].unsqueeze(0) + if self.config["cls_name"] == 'grid_proto_3d': # 3D + k_size = self.cls_unit.kernel_size + fg_mode = FG_PROT_MODE if F.avg_pool3d(_msk, k_size).max( + ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' # TODO figure out kernel size + else: + k_size = self.cls_unit.kernel_size + fg_mode = FG_PROT_MODE if F.avg_pool2d(_msk, k_size).max( + ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' + # TODO figure out kernel size + _raw_score, _, aux_attr, proto_grid = self.cls_unit(qry_fts, supp_ft, _msk.unsqueeze( + 0), mode=fg_mode, thresh=FG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz) + raw_scores.append(_raw_score) + + # create a score where each feature is the max of the raw_score + _raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[ + 0] + scores.append(_raw_score) + assign_maps.append(aux_attr['proto_assign']) + if show_viz: + fg_sim_maps.append(aux_attr['raw_local_sims']) + # print(f"Time for fg: {time.time() - start_time}") + pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W' + interpolate_mode = 'bilinear' + outputs.append(F.interpolate( + pred, size=img_size, mode=interpolate_mode)) + + ###### Prototype alignment loss ###### + if self.config['align'] and self.training: + align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi], + fore_mask[:, :, epi], back_mask[:, :, epi]) + align_loss += align_loss_epi + + output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W + grid_shape = output.shape[2:] + if self.config["cls_name"] == 'grid_proto_3d': + grid_shape = output.shape[2:] + output = output.view(-1, *grid_shape) + assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None + bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz else None + fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz else None + + return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts + + + def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask): + """ + Compute the loss for the prototype alignment branch + + Args: + qry_fts: embedding features for query images + expect shape: N x C x H' x W' + pred: predicted segmentation score + expect shape: N x (1 + Wa) x H x W + supp_fts: embedding fatures for support images + expect shape: Wa x Sh x C x H' x W' + fore_mask: foreground masks for support images + expect shape: way x shot x H x W + back_mask: background masks for support images + expect shape: way x shot x H x W + """ + n_ways, n_shots = len(fore_mask), len(fore_mask[0]) + + # Masks for getting query prototype + pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W' + binary_masks = [pred_mask == i for i in range(1 + n_ways)] + + # skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0] + # FIXME: fix this in future we here make a stronger assumption that a positive class must be there to avoid undersegmentation/ lazyness + skip_ways = [] + + # added for matching dimensions to the new data format + qry_fts = qry_fts.unsqueeze(0).unsqueeze( + 2) # added to nway(1) and nb(1) + # end of added part + + loss = [] + for way in range(n_ways): + if way in skip_ways: + continue + # Get the query prototypes + for shot in range(n_shots): + # actual local query [way(1), nb(1, nb is now nshot), nc, h, w] + img_fts = supp_fts[way: way + 1, shot: shot + 1] + size = img_fts.shape[-2:] + mode = 'bilinear' + if self.config["cls_name"] == 'grid_proto_3d': + size = img_fts.shape[-3:] + mode = 'trilinear' + qry_pred_fg_msk = F.interpolate( + binary_masks[way + 1].float(), size=size, mode=mode) # [1 (way), n (shot), h, w] + + # background + qry_pred_bg_msk = F.interpolate( + binary_masks[0].float(), size=size, mode=mode) # 1, n, h ,w + scores = [] + + bg_mode = BG_PROT_MODE + _raw_score_bg, _, _, _ = self.cls_unit( + qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3), mode=bg_mode, thresh=BG_THRESH) + + scores.append(_raw_score_bg) + if self.config["cls_name"] == 'grid_proto_3d': + fg_mode = FG_PROT_MODE if F.avg_pool3d(qry_pred_fg_msk, 4).max( + ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' + else: + fg_mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max( + ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' + _raw_score_fg, _, _, _ = self.cls_unit( + qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2), mode=fg_mode, thresh=FG_THRESH) + scores.append(_raw_score_fg) + + supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W' + size = fore_mask.shape[-2:] + if self.config["cls_name"] == 'grid_proto_3d': + size = fore_mask.shape[-3:] + supp_pred = F.interpolate(supp_pred, size=size, mode=mode) + + # Construct the support Ground-Truth segmentation + supp_label = torch.full_like(fore_mask[way, shot], 255, + device=img_fts.device).long() + supp_label[fore_mask[way, shot] == 1] = 1 + supp_label[back_mask[way, shot] == 1] = 0 + # Compute Loss + loss.append(F.cross_entropy( + supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways) + + return torch.sum(torch.stack(loss)) + + def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens): + cls_loss_weight = 0.1 + student_temp = 1 + teacher_cls_tokens = self.sinkhorn_knopp_teacher(teacher_cls_tokens) + lsm = F.log_softmax(student_cls_tokens / student_temp, dim=-1) + cls_loss = torch.sum(teacher_cls_tokens * lsm, dim=-1) + + return -cls_loss.mean() * cls_loss_weight + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + # Q is K-by-B for consistency with notations from our paper + Q = torch.exp(teacher_output / teacher_temp).t() + # B = Q.shape[1] * world_size # number of samples to assign + B = Q.shape[1] + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def dino_patch_loss(self, features, masked_features, masks): + # for both supp and query features perform the patch wise loss + loss = 0.0 + weight = 0.1 + B = features.shape[0] + for (f, mf, mask) in zip(features, masked_features, masks): + # TODO sinkhorn knopp center features + f = f[mask] + f = self.sinkhorn_knopp_teacher(f) + mf = mf[mask] + loss += torch.sum(f * F.log_softmax(mf / 1, + dim=-1), dim=-1) / mask.sum() + + return -loss.sum() * weight / B diff --git a/models/segment_anything/__init__.py b/models/segment_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23e6d84e0e9041f4e72e2fab4d01a38744fcd571 --- /dev/null +++ b/models/segment_anything/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/models/segment_anything/__pycache__/__init__.cpython-312.pyc b/models/segment_anything/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cf08cb5fa1c8c737a4536ea809e560d6eeec88f Binary files /dev/null and b/models/segment_anything/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc b/models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e655b3a065480bcfd6b6de7b635bc41e92d284c Binary files /dev/null and b/models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc differ diff --git a/models/segment_anything/__pycache__/build_sam.cpython-312.pyc b/models/segment_anything/__pycache__/build_sam.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c991cb724e26b739b4db42bb7b3d8102d1bd14 Binary files /dev/null and b/models/segment_anything/__pycache__/build_sam.cpython-312.pyc differ diff --git a/models/segment_anything/__pycache__/predictor.cpython-312.pyc b/models/segment_anything/__pycache__/predictor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a754111fbc80dbcb37f34b21b0d85d259ecb91 Binary files /dev/null and b/models/segment_anything/__pycache__/predictor.cpython-312.pyc differ diff --git a/models/segment_anything/automatic_mask_generator.py b/models/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e8a52ae0500004786f1731820a9fbaaea56820 --- /dev/null +++ b/models/segment_anything/automatic_mask_generator.py @@ -0,0 +1,380 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + custom_points: bool = "false", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.custom_points = custom_points + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + if self.custom_points: + in_pos_labels = torch.ones(in_points.shape[0]//2, dtype=torch.int, device=in_points.device) + in_neg_labels = torch.zeros_like(in_pos_labels) + in_labels = torch.cat((in_pos_labels, in_neg_labels), dim = 0) + else: + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/models/segment_anything/build_sam.py b/models/segment_anything/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc353340ca4fd5c29db7dfae3494ffb4c4c9502 --- /dev/null +++ b/models/segment_anything/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SamBatched + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = SamBatched( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/models/segment_anything/modeling/__init__.py b/models/segment_anything/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a25e343d958d3b60feb84b92f6948ffd1f2a21dd --- /dev/null +++ b/models/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam, SamBatched +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3833dcb89122f9a2b8d42f419928146c740dba27 Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/common.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47894cadaf295efbb99951120dfd18e384fd72f3 Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/common.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43446357b93f8ae0dda00a4b4a4a60e06475edb Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2d3a49dfc211011182746852d5c6a45d819a4b Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32500d9dcf256a0f9b170d4fb09dc5fad1352914 Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8900c44916b52e74deed3305dfd151788f636e0f Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/__pycache__/transformer.cpython-312.pyc b/models/segment_anything/modeling/__pycache__/transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ab586ee2f5271ce8faf49ad244e31463ec1ea77 Binary files /dev/null and b/models/segment_anything/modeling/__pycache__/transformer.cpython-312.pyc differ diff --git a/models/segment_anything/modeling/common.py b/models/segment_anything/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12 --- /dev/null +++ b/models/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/models/segment_anything/modeling/image_encoder.py b/models/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..224bffa56aad613c52ef8b2b469c1f518c05f7e3 --- /dev/null +++ b/models/segment_anything/modeling/image_encoder.py @@ -0,0 +1,406 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + use_grad_checkpointing: bool = False, + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + self.use_grad_checkpointing = use_grad_checkpointing + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + if self.use_grad_checkpointing: + blk.use_grad_checkpointing = True + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + use_grad_checkpointing: bool = False, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.use_grad_checkpointing = use_grad_checkpointing + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + if self.use_grad_checkpointing: + x = torch.utils.checkpoint.checkpoint(self.attn, x) + else: + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/models/segment_anything/modeling/mask_decoder.py b/models/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4fdb868e1b0340d1bb6b1ee84a20eca27be455 --- /dev/null +++ b/models/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/models/segment_anything/modeling/prompt_encoder.py b/models/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475 --- /dev/null +++ b/models/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/models/segment_anything/modeling/sam.py b/models/segment_anything/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..aac8cce7babf015a7ae3356e092aa081bcfa076f --- /dev/null +++ b/models/segment_anything/modeling/sam.py @@ -0,0 +1,333 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + # @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="nearest" + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="nearest") + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + +class SamBatched(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + # @torch.no_grad() + def forward( + self, + batched_input: torch.Tensor, + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + with torch.no_grad(): + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image_size"], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=True, + ) + masks = masks[..., : int(input_size[0]), : int(input_size[1])] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=True) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x \ No newline at end of file diff --git a/models/segment_anything/modeling/transformer.py b/models/segment_anything/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d99f8e8265b5780dd3be1d8c6bbd33156ac1d8f4 --- /dev/null +++ b/models/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/models/segment_anything/predictor.py b/models/segment_anything/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..dadf27ea8e9962418f7714d08ebfadcf9c4d3182 --- /dev/null +++ b/models/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/models/segment_anything/utils/__init__.py b/models/segment_anything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/models/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/segment_anything/utils/__pycache__/__init__.cpython-312.pyc b/models/segment_anything/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08aafbdaeb0c6dec7e8adb88465ed99911a5e238 Binary files /dev/null and b/models/segment_anything/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/segment_anything/utils/__pycache__/amg.cpython-312.pyc b/models/segment_anything/utils/__pycache__/amg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b23fbac1c6e6bbfcffec57c790994544e96139c Binary files /dev/null and b/models/segment_anything/utils/__pycache__/amg.cpython-312.pyc differ diff --git a/models/segment_anything/utils/__pycache__/transforms.cpython-312.pyc b/models/segment_anything/utils/__pycache__/transforms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ced6b7a813dc9fed97c230f180afe7620eec2ca Binary files /dev/null and b/models/segment_anything/utils/__pycache__/transforms.cpython-312.pyc differ diff --git a/models/segment_anything/utils/amg.py b/models/segment_anything/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..1464be421bbf473e0131e1be1c40e869b060a86d --- /dev/null +++ b/models/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/models/segment_anything/utils/onnx.py b/models/segment_anything/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9d9e2f1c5990f6b279ef7d1bb847063c68e5e --- /dev/null +++ b/models/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/models/segment_anything/utils/transforms.py b/models/segment_anything/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..165d19be566815164dfb37d10d08e21d369d34dd --- /dev/null +++ b/models/segment_anything/utils/transforms.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore +from typing import List + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375],) -> None: + + self.target_length = target_length + self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) + self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) + + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[-2], image.shape[-1], self.target_length) + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = F.interpolate( + image, target_size, + mode="bilinear", + align_corners=False, + antialias=True + ) + return image.squeeze(0) + elif len(image.shape) == 2: + image = image.unsqueeze(0).unsqueeze(0) + image = F.interpolate( + image, target_size, + mode="bilinear", + align_corners=False, + antialias=True + ) + return image.squeeze(0).squeeze(0) + + else: + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + if len(x.shape)==2: + pass + else: + device = x.device + x = (x - self.pixel_mean.to(device)) / self.pixel_std.to(device) # TODO uncomment this + # x = x / 255 + pass + + # Pad + h, w = x.shape[-2:] + padh = self.target_length - h + padw = self.target_length - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..326776d0f0dcb10584469b59f406f077188f9491 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +torch==2.6.0 +torchvision==0.21.0 +numpy==1.26.4 +matplotlib==3.10.3 +pillow==11.2.1 +opencv-python==4.11.0.86 +tqdm==4.67.1 +sacred==0.8.7 +gradio==5.29.0 +safetensors==0.5.3 +segment-anything==1.0 +kneed==0.8.5 +scikit-image==0.24.0 +scikit-learn==1.6.1 + +# tqdm==4.67.1 +# mediapipe==0.10.21 +# opencv-python==4.11.0.86 +# numpy==1.26.4 +# pandas==2.2.3 +# torch==2.6.0 +# torchvision==0.21.0 +# scikit-learn==1.6.1 +# matplotlib==3.10.1 +# pillow==11.2.1 +# gradio==5.29.0 diff --git a/run_demo.sh b/run_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..426aacbabf83f5f4b4f3aa6b1b9c34cd86b27f51 --- /dev/null +++ b/run_demo.sh @@ -0,0 +1,14 @@ +# Download SAM model if it doesn't exist +if [ ! -d "pretrained_model" ]; then + mkdir -p pretrained_model +fi + +if [ ! -f "pretrained_model/sam_vit_h.pth" ]; then + echo "Downloading SAM ViT-H model..." + wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth + mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth +fi + +# Run the app +echo "Running ProtoSAM demo..." +python app.py \ No newline at end of file diff --git a/run_protosam.sh b/run_protosam.sh new file mode 100644 index 0000000000000000000000000000000000000000..ea0d34165f6b35166a417f851554df1e638084cb --- /dev/null +++ b/run_protosam.sh @@ -0,0 +1,124 @@ +#!/bin/bash +set -e +GPUID1=0 +export CUDA_VISIBLE_DEVICES=$GPUID1 + +# Configs +MODEL_NAME='dinov2_l14' # relevant for ALPNET, aviailable: dinov2_l14, dinov2_l14_reg, dinov2_b14, dinov2_b14_reg, dlfcn_res101 (deeplabv3) +COARSE_PRED_ONLY="False" # True will output the coarse segmentation result +PROTOSAM_SAM_VER="sam_h" # available: sam_h, sam_b, medsam +INPUT_SIZE=256 # resolution +ORGAN="rk" # relevant for MRI and CT, available: rk, lk, liver, spleen + +# get modality as arg +MODALITY=$1 + +PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training +ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4) +SEED=42 + +if [ $MODALITY != "ct" ] && [ $MODALITY != "mri" ] && [ $MODALITY != "polyp" ] +then + echo "modality must be either ct ,mri or polyp" + exit 1 +fi + +if [ $MODALITY == "ct" ] +then + DATASET='SABS_Superpix' +fi +if [ $MODALITY == "mri" ] +then + DATASET='CHAOST2_Superpix' +fi +if [ $MODALITY == "polyp" ] +then + DATASET='polyps' +fi + +if [ $INPUT_SIZE -gt 256 ] +then + DATASET=${DATASET}'_672' +fi + +NWORKER=4 +LORA=0 +RELOAD_PATH=( "None" ) +SKIP_SLICES="True" +DO_CCA="True" +ALL_SCALE=( "MIDDLE") # config of pseudolabels + +if [ $MODALITY == "polyp" ] +then + ORGAN="polyps" +fi + +FREE_DESC="" +CPT="${MODEL_NAME}_${MODALITY}" +if [ -n "$FREE_DESC" ] +then + CPT="${CPT}_${FREE_DESC}" +fi + +if [ $LORA -ne 0 ] +then + CPT="${CPT}_lora_${LORA}" +fi + +if [ $DO_CCA = "True" ] +then + CPT="${CPT}_cca" +fi + +CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}_${ORGAN}_fold" + +SUPP_ID='[6]' +if [ $MODALITY == "mri" ] +then + SUPP_ID='[4]' +fi + +echo =================================== + +for ((i=0; i<${#ALL_EV[@]}; i++)) +do + EVAL_FOLD=${ALL_EV[i]} + CPT_W_FOLD="${CPT}_${EVAL_FOLD}" + echo $CPT_W_FOLD on GPU $GPUID1 + for SUPERPIX_SCALE in "${ALL_SCALE[@]}" + do + PREFIX="test_vfold${EVAL_FOLD}" + echo $PREFIX + LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}" + + if [ ! -d $LOGDIR ] + then + mkdir -p $LOGDIR + fi + + python3 validation_protosam.py with \ + "modelname=$MODEL_NAME" \ + "base_model=alpnet" \ + "coarse_pred_only=$COARSE_PRED_ONLY" \ + "protosam_sam_ver=$PROTOSAM_SAM_VER" \ + "curr_cls=$ORGAN" \ + 'usealign=True' \ + 'optim_type=sgd' \ + reload_model_path=${RELOAD_PATH[i]} \ + num_workers=$NWORKER \ + scan_per_load=-1 \ + 'use_wce=True' \ + exp_prefix=$PREFIX \ + 'clsname=grid_proto' \ + eval_fold=$EVAL_FOLD \ + dataset=$DATASET \ + proto_grid_size=$PROTO_GRID \ + min_fg_data=1 seed=$SEED \ + save_snapshot_every=$SNAPSHOT_INTERVAL \ + superpix_scale=$SUPERPIX_SCALE \ + path.log_dir=$LOGDIR \ + support_idx=$SUPP_ID \ + lora=$LORA \ + "input_size=($INPUT_SIZE, $INPUT_SIZE)" + done +done \ No newline at end of file diff --git a/training.py b/training.py new file mode 100644 index 0000000000000000000000000000000000000000..1719c450f9856d02cd70bf8c33cbcb2ddae114ab --- /dev/null +++ b/training.py @@ -0,0 +1,250 @@ +""" +Training the model +Extended from original implementation of ALPNet. +""" +from scipy.ndimage import distance_transform_edt as eucl_distance +import os +import shutil +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import MultiStepLR +import numpy as np +from models.grid_proto_fewshot import FewShotSeg +from torch.utils.tensorboard import SummaryWriter +from dataloaders.dev_customized_med import med_fewshot +from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset +import dataloaders.augutils as myaug + +from util.utils import set_seed, t2n, to01, compose_wt_simple +from util.metric import Metric + +from config_ssl_upload import ex +from tqdm.auto import tqdm +# import Tensor +from torch import Tensor +from typing import List, Tuple, Union, cast, Iterable, Set, Any, Callable, TypeVar + +def get_dice_loss(prediction: torch.Tensor, target: torch.Tensor, smooth=1.0): + ''' + prediction: (B, 1, H, W) + target: (B, H, W) + ''' + if prediction.shape[1] > 1: + # use only the foreground prediction + prediction = prediction[:, 1, :, :] + prediction = torch.sigmoid(prediction) + intersection = (prediction * target).sum(dim=(-2, -1)) + union = prediction.sum(dim=(-2, -1)) + target.sum(dim=(1, 2)) + smooth + + dice = (2.0 * intersection + smooth) / union + dice_loss = 1.0 - dice.mean() + + return dice_loss + + +def get_train_transforms(_config): + tr_transforms = myaug.transform_with_label( + {'aug': myaug.get_aug(_config['which_aug'], _config['input_size'][0])}) + return tr_transforms + + +def get_dataset_base_name(data_name): + if data_name == 'SABS_Superpix': + baseset_name = 'SABS' + elif data_name == 'C0_Superpix': + raise NotImplementedError + baseset_name = 'C0' + elif data_name == 'CHAOST2_Superpix': + baseset_name = 'CHAOST2' + elif data_name == 'CHAOST2_Superpix_672': + baseset_name = 'CHAOST2' + elif data_name == 'SABS_Superpix_448': + baseset_name = 'SABS' + elif data_name == 'SABS_Superpix_672': + baseset_name = 'SABS' + elif 'lits' in data_name.lower(): + baseset_name = 'LITS17' + else: + raise ValueError(f'Dataset: {data_name} not found') + + return baseset_name + +def get_nii_dataset(_config): + data_name = _config['dataset'] + baseset_name = get_dataset_base_name(data_name) + tr_transforms = get_train_transforms(_config) + tr_parent = SuperpixelDataset( # base dataset + which_dataset=baseset_name, + base_dir=_config['path'][data_name]['data_dir'], + idx_split=_config['eval_fold'], + mode='train', + # dummy entry for superpixel dataset + min_fg=str(_config["min_fg_data"]), + image_size=_config["input_size"][0], + transforms=tr_transforms, + nsup=_config['task']['n_shots'], + scan_per_load=_config['scan_per_load'], + exclude_list=_config["exclude_cls_list"], + superpix_scale=_config["superpix_scale"], + fix_length=_config["max_iters_per_load"] if (data_name == 'C0_Superpix') or ( + data_name == 'CHAOST2_Superpix') else _config["max_iters_per_load"], + use_clahe=_config['use_clahe'], + use_3_slices=_config["use_3_slices"], + tile_z_dim=3 if not _config["use_3_slices"] else 1, + ) + + return tr_parent + + +def get_dataset(_config): + return get_nii_dataset(_config) + + +@ex.automain +def main(_run, _config, _log): + precision = torch.float32 + torch.autograd.set_detect_anomaly(True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if _run.observers: + os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True) + for source_file, _ in _run.experiment_info['sources']: + os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), + exist_ok=True) + _run.observers[0].save_file(source_file, f'source/{source_file}') + shutil.rmtree(f'{_run.observers[0].basedir}/_sources') + + set_seed(_config['seed']) + + writer = SummaryWriter(f'{_run.observers[0].dir}/logs') + _log.info('###### Create model ######') + if _config['reload_model_path'] != '': + _log.info(f'###### Reload model {_config["reload_model_path"]} ######') + else: + _config['reload_model_path'] = None + model = FewShotSeg(image_size=_config['input_size'][0], pretrained_path=_config['reload_model_path'], cfg=_config['model']) + + model = model.to(device, precision) + model.train() + + _log.info('###### Load data ######') + data_name = _config['dataset'] + tr_parent = get_dataset(_config) + + # dataloaders + trainloader = DataLoader( + tr_parent, + batch_size=_config['batch_size'], + shuffle=True, + num_workers=_config['num_workers'], + pin_memory=True, + drop_last=True + ) + + _log.info('###### Set optimizer ######') + if _config['optim_type'] == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), **_config['optim']) + elif _config['optim_type'] == 'adam': + optimizer = torch.optim.AdamW( + model.parameters(), lr=_config['lr'], eps=1e-5) + else: + raise NotImplementedError + + scheduler = MultiStepLR( + optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma']) + + my_weight = compose_wt_simple(_config["use_wce"], data_name) + criterion = nn.CrossEntropyLoss( + ignore_index=_config['ignore_label'], weight=my_weight) + + i_iter = 0 # total number of iteration + # number of times for reloading + n_sub_epoches = max(1, _config['n_steps'] // _config['max_iters_per_load'], _config["epochs"]) + log_loss = {'loss': 0, 'align_loss': 0} + + _log.info('###### Training ######') + epoch_losses = [] + for sub_epoch in range(1): + print(f"Epoch: {sub_epoch}") + _log.info( + f'###### This is epoch {sub_epoch} of {n_sub_epoches} epoches ######') + pbar = tqdm(trainloader) + optimizer.zero_grad() + for idx, sample_batched in enumerate(tqdm(trainloader)): + losses = [] + i_iter += 1 + support_images = [[shot.to(device, precision) for shot in way] + for way in sample_batched['support_images']] + support_fg_mask = [[shot[f'fg_mask'].float().to(device, precision) for shot in way] + for way in sample_batched['support_mask']] + support_bg_mask = [[shot[f'bg_mask'].float().to(device, precision) for shot in way] + for way in sample_batched['support_mask']] + + query_images = [query_image.to(device, precision) + for query_image in sample_batched['query_images']] + query_labels = torch.cat( + [query_label.long().to(device) for query_label in sample_batched['query_labels']], dim=0) + + loss = 0.0 + try: + out = model(support_images, support_fg_mask, support_bg_mask, + query_images, isval=False, val_wsize=None) + query_pred, align_loss, _, _, _, _, _ = out + # pred = np.array(query_pred.argmax(dim=1)[0].cpu()) + except Exception as e: + print(f'faulty batch detected, skip: {e}') + # offload cuda memory + del support_images, support_fg_mask, support_bg_mask, query_images, query_labels + continue + + query_loss = criterion(query_pred.float(), query_labels.long()) + loss += query_loss + align_loss + pbar.set_postfix({'loss': loss.item()}) + loss.backward() + if (idx + 1) % _config['grad_accumulation_steps'] == 0: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + + losses.append(loss.item()) + query_loss = query_loss.detach().data.cpu().numpy() + align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0 + + _run.log_scalar('loss', query_loss) + _run.log_scalar('align_loss', align_loss) + + log_loss['loss'] += query_loss + log_loss['align_loss'] += align_loss + + # print loss and take snapshots + if (i_iter + 1) % _config['print_interval'] == 0: + writer.add_scalar('loss', loss, i_iter) + writer.add_scalar('query_loss', query_loss, i_iter) + writer.add_scalar('align_loss', align_loss, i_iter) + + loss = log_loss['loss'] / _config['print_interval'] + align_loss = log_loss['align_loss'] / _config['print_interval'] + + log_loss['loss'] = 0 + log_loss['align_loss'] = 0 + + print( + f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss},') + + if (i_iter + 1) % _config['save_snapshot_every'] == 0: + _log.info('###### Taking snapshot ######') + torch.save(model.state_dict(), + os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth')) + + if (i_iter - 1) >= _config['n_steps']: + break # finish up + epoch_losses.append(np.mean(losses)) + print(f"Epoch {sub_epoch} loss: {np.mean(losses)}") + + # Save the final model regardless of iteration count + _log.info('###### Saving final model ######') + final_save_path = os.path.join(f'{_run.observers[0].dir}/snapshots', f'final_model.pth') + torch.save(model.state_dict(), final_save_path) + print(f"Final model saved to: {final_save_path}") diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/util/__pycache__/__init__.cpython-312.pyc b/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f096d96f46d25d43b251460b7a559325624530fc Binary files /dev/null and b/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/util/__pycache__/consts.cpython-312.pyc b/util/__pycache__/consts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4830266e1486f0303f5416abcd9fa019d98834b9 Binary files /dev/null and b/util/__pycache__/consts.cpython-312.pyc differ diff --git a/util/__pycache__/lora.cpython-312.pyc b/util/__pycache__/lora.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed843c281fdb1bdd443dd31d1a62184b1e04a3b Binary files /dev/null and b/util/__pycache__/lora.cpython-312.pyc differ diff --git a/util/__pycache__/utils.cpython-312.pyc b/util/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50d5cd701cd4b13d759acdb6df2313f50dd6f195 Binary files /dev/null and b/util/__pycache__/utils.cpython-312.pyc differ diff --git a/util/consts.py b/util/consts.py new file mode 100644 index 0000000000000000000000000000000000000000..d82d6dc0e9d35dd275aa49ec0c7ed4ebfbff3445 --- /dev/null +++ b/util/consts.py @@ -0,0 +1,2 @@ +IMG_SIZE=252 # 256 is original +DEFAULT_FEATURE_SIZE=32 \ No newline at end of file diff --git a/util/lora.py b/util/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2533a01fa77df7de5b35e7e63a62bc5f38e5cf61 --- /dev/null +++ b/util/lora.py @@ -0,0 +1,1113 @@ +# copied from https://github.com/cloneofsimo/lora +import json +import math +from itertools import groupby +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from safetensors.torch import safe_open + from safetensors.torch import save_file as safe_save + + safetensors_available = True +except ImportError: + from .safe_open import safe_open + + def safe_save( + tensors: Dict[str, torch.Tensor], + filename: str, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + raise EnvironmentError( + "Saving safetensors requires the safetensors library. Please install with pip or similar." + ) + + safetensors_available = False + + +class LoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 + ): + super().__init__() + + if r > min(in_features, out_features): + raise ValueError( + f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + ) + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + self.lora_down = nn.Linear(in_features, r, bias=False) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + self.selector = nn.Identity() + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.linear(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + r: int = 4, + dropout_p: float = 0.1, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + raise ValueError( + f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}" + ) + self.r = r + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.lora_down = nn.Conv2d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv2d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv2d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} + +TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} + +TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} + +DINO_TARGET_REPLACE = {"NestedTensorBlock", "Mlp", "Attention", "MemEffAttention"} + +DEFAULT_TARGET_REPLACE = DINO_TARGET_REPLACE + +EMBED_FLAG = "" + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoraInjectedLinear, + LoraInjectedConv2d, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = ( + module + for module in model.modules() + if module.__class__.__name__ in ancestor_class + ) + else: + # this, incase you want to naively iterate over all modules. + ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + while path: + parent = parent.get_submodule(path.pop(0)) + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): + continue + # Otherwise, yield it + yield parent, name, module + + +def _find_modules_old( + model, + ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], +): + ret = [] + for _module in model.modules(): + if _module.__class__.__name__ in ancestor_class: + + for name, _child_module in _module.named_modules(): + if _child_module.__class__ in search_class: + ret.append((_module, name, _child_module)) + print(ret) + return ret + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora( + model: nn.Module, + target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + verbose: bool = False, + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear] + ): + weight = _child_module.weight + bias = _child_module.bias + if verbose: + print("LoRA Injection : injecting lora into ", name) + print("LoRA Injection : weight shape", weight.shape) + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + loras.append((_child_module.lora_up, _child_module.lora_down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def extract_lora_as_tensor( + model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True +): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + up, down = _child_module.realize_as_lora() + if as_fp16: + up = up.to(torch.float16) + down = down.to(torch.float16) + + loras.append((up, down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def save_lora_weight( + model, + path="./lora.pt", + target_replace_module=DEFAULT_TARGET_REPLACE, +): + weights = [] + for _up, _down in extract_lora_ups_down( + model, target_replace_module=target_replace_module + ): + weights.append(_up.weight.to("cpu").to(torch.float16)) + weights.append(_down.weight.to("cpu").to(torch.float16)) + + torch.save(weights, path) + + +def save_lora_as_json(model, path="./lora.json"): + weights = [] + for _up, _down in extract_lora_ups_down(model): + weights.append(_up.weight.detach().cpu().numpy().tolist()) + weights.append(_down.weight.detach().cpu().numpy().tolist()) + + import json + + with open(path, "w") as f: + json.dump(weights, f) + + +def save_safeloras_with_embeds( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Saves the Lora from multiple modules in a single safetensor file. + + modelmap is a dictionary of { + "module name": (module, target_replace_module) + } + """ + weights = {} + metadata = {} + + for name, (model, target_replace_module) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + for i, (_up, _down) in enumerate( + extract_lora_as_tensor(model, target_replace_module) + ): + rank = _down.shape[0] + + metadata[f"{name}:{i}:rank"] = str(rank) + weights[f"{name}:{i}:up"] = _up + weights[f"{name}:{i}:down"] = _down + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Converts the Lora from multiple pytorch .pt files into a single safetensor file. + + modelmap is a dictionary of { + "module name": (pytorch_model_path, target_replace_module, rank) + } + """ + + weights = {} + metadata = {} + + for name, (path, target_replace_module, r) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + lora = torch.load(path) + for i, weight in enumerate(lora): + is_up = i % 2 == 0 + i = i // 2 + + if is_up: + metadata[f"{name}:{i}:rank"] = str(r) + weights[f"{name}:{i}:up"] = weight + else: + weights[f"{name}:{i}:down"] = weight + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def parse_safeloras( + safeloras, +) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: + """ + Converts a loaded safetensor file that contains a set of module Loras + into Parameters and other information + + Output is a dictionary of { + "module name": ( + [list of weights], + [list of ranks], + target_replacement_modules + ) + } + """ + loras = {} + metadata = safeloras.metadata() + + get_name = lambda k: k.split(":")[0] + + keys = list(safeloras.keys()) + keys.sort(key=get_name) + + for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" + ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras + # Extract the targets + target = json.loads(info) + + # Build the result lists - Python needs us to preallocate lists to insert into them + module_keys = list(module_keys) + ranks = [4] * (len(module_keys) // 2) + weights = [None] * len(module_keys) + + for key in module_keys: + # Split the model name and index out of the key + _, idx, direction = key.split(":") + idx = int(idx) + + # Add the rank + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) + + # Insert the weight into the list + idx = idx * 2 + (1 if direction == "down" else 0) + weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) + + loras[name] = (weights, ranks, target) + + return loras + + +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() + + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + + +def load_safeloras(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras) + + +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + +def collapse_lora(model, alpha=1.0): + + for _module, name, _child_module in _find_modules( + model, + UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + + if isinstance(_child_module, LoraInjectedLinear): + print("Collapsing Lin Lora in", name) + + _child_module.linear.weight = nn.Parameter( + _child_module.linear.weight.data + + alpha + * ( + _child_module.lora_up.weight.data + @ _child_module.lora_down.weight.data + ) + .type(_child_module.linear.weight.dtype) + .to(_child_module.linear.weight.device) + ) + + else: + print("Collapsing Conv Lora in", name) + _child_module.conv.weight = nn.Parameter( + _child_module.conv.weight.data + + alpha + * ( + _child_module.lora_up.weight.data.flatten(start_dim=1) + @ _child_module.lora_down.weight.data.flatten(start_dim=1) + ) + .reshape(_child_module.conv.weight.data.shape) + .type(_child_module.conv.weight.dtype) + .to(_child_module.conv.weight.device) + ) + + +def monkeypatch_or_replace_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + ): + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_lora_extended( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, + target_replace_module, + search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d], + ): + + if (_child_module.__class__ == nn.Linear) or ( + _child_module.__class__ == LoraInjectedLinear + ): + if len(loras[0].shape) != 2: + continue + + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + elif (_child_module.__class__ == nn.Conv2d) or ( + _child_module.__class__ == LoraInjectedConv2d + ): + if len(loras[0].shape) != 4: + continue + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv2d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv2d( + _source.in_channels, + _source.out_channels, + _source.kernel_size, + _source.stride, + _source.padding, + _source.dilation, + _source.groups, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_safeloras(models, safeloras): + loras = parse_safeloras(safeloras) + + for name, (lora, ranks, target) in loras.items(): + model = getattr(models, name, None) + + if not model: + print(f"No model provided for {name}, contained in Lora") + continue + + monkeypatch_or_replace_lora_extended(model, lora, target, ranks) + + +def monkeypatch_remove_lora(model): + for _module, name, _child_module in _find_modules( + model, search_class=[LoraInjectedLinear, LoraInjectedConv2d] + ): + if isinstance(_child_module, LoraInjectedLinear): + _source = _child_module.linear + weight, bias = _source.weight, _source.bias + + _tmp = nn.Linear( + _source.in_features, _source.out_features, bias is not None + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + else: + _source = _child_module.conv + weight, bias = _source.weight, _source.bias + + _tmp = nn.Conv2d( + in_channels=_source.in_channels, + out_channels=_source.out_channels, + kernel_size=_source.kernel_size, + stride=_source.stride, + padding=_source.padding, + dilation=_source.dilation, + groups=_source.groups, + bias=bias is not None, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + _module._modules[name] = _tmp + + +def monkeypatch_add_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + alpha: float = 1.0, + beta: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[LoraInjectedLinear] + ): + weight = _child_module.linear.weight + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_up.weight.to(weight.device) * beta + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_down.weight.to(weight.device) * beta + ) + + _module._modules[name].to(weight.device) + + +def tune_lora_scale(model, alpha: float = 1.0): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + _module.scale = alpha + + +def set_lora_diag(model, diag: torch.Tensor): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + _module.set_selector_from_diag(diag) + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def _ti_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["ti", "pt"]) + + +def apply_learned_embed_in_clip( + learned_embeds, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + if isinstance(token, str): + trained_tokens = [token] + elif isinstance(token, list): + assert len(learned_embeds.keys()) == len( + token + ), "The number of tokens and the number of embeds should be the same" + trained_tokens = token + else: + trained_tokens = list(learned_embeds.keys()) + + for token in trained_tokens: + print(token) + embeds = learned_embeds[token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + num_added_tokens = tokenizer.add_tokens(token) + + i = 1 + if not idempotent: + while num_added_tokens == 0: + print(f"The tokenizer already contains the token {token}.") + token = f"{token[:-1]}-{i}>" + print(f"Attempting to add the token {token}.") + num_added_tokens = tokenizer.add_tokens(token) + i += 1 + elif num_added_tokens == 0 and idempotent: + print(f"The tokenizer already contains the token {token}.") + print(f"Replacing {token} embedding.") + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token + + +def load_learned_embed_in_clip( + learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + +def patch_pipe( + pipe, + maybe_unet_path, + token: Optional[str] = None, + r: int = 4, + patch_unet=True, + patch_text=True, + patch_ti=True, + idempotent_token=True, + unet_target_replace_module=DEFAULT_TARGET_REPLACE, + text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +): + if maybe_unet_path.endswith(".pt"): + # torch format + + if maybe_unet_path.endswith(".ti.pt"): + unet_path = maybe_unet_path[:-6] + ".pt" + elif maybe_unet_path.endswith(".text_encoder.pt"): + unet_path = maybe_unet_path[:-16] + ".pt" + else: + unet_path = maybe_unet_path + + ti_path = _ti_lora_path(unet_path) + text_path = _text_lora_path(unet_path) + + if patch_unet: + print("LoRA : Patching Unet") + monkeypatch_or_replace_lora( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) + + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r, + ) + if patch_ti: + print("LoRA : Patching token input") + token = load_learned_embed_in_clip( + ti_path, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + + elif maybe_unet_path.endswith(".safetensors"): + safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") + monkeypatch_or_replace_safeloras(pipe, safeloras) + tok_dict = parse_safeloras_embeds(safeloras) + if patch_ti: + apply_learned_embed_in_clip( + tok_dict, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + return tok_dict + + +@torch.no_grad() +def inspect_lora(model): + moved = {} + + for name, _module in model.named_modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + ups = _module.lora_up.weight.data.clone() + downs = _module.lora_down.weight.data.clone() + + wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) + + dist = wght.flatten().abs().mean().item() + if name in moved: + moved[name].append(dist) + else: + moved[name] = [dist] + + return moved + + +def save_all( + unet, + text_encoder, + save_path, + placeholder_token_ids=None, + placeholder_tokens=None, + save_lora=True, + save_ti=True, + target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + target_replace_module_unet=DEFAULT_TARGET_REPLACE, + safe_form=True, +): + if not safe_form: + # save ti + if save_ti: + ti_path = _ti_lora_path(save_path) + learned_embeds_dict = {} + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + learned_embeds_dict[tok] = learned_embeds.detach().cpu() + + torch.save(learned_embeds_dict, ti_path) + print("Ti saved to ", ti_path) + + # save text encoder + if save_lora: + + save_lora_weight( + unet, save_path, target_replace_module=target_replace_module_unet + ) + print("Unet saved to ", save_path) + + save_lora_weight( + text_encoder, + _text_lora_path(save_path), + target_replace_module=target_replace_module_text, + ) + print("Text Encoder saved to ", _text_lora_path(save_path)) + + else: + assert save_path.endswith( + ".safetensors" + ), f"Save path : {save_path} should end with .safetensors" + + loras = {} + embeds = {} + + if save_lora: + + loras["unet"] = (unet, target_replace_module_unet) + loras["text_encoder"] = (text_encoder, target_replace_module_text) + + if save_ti: + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + embeds[tok] = learned_embeds.detach().cpu() + + save_safeloras_with_embeds(loras, embeds, save_path) diff --git a/util/metric.py b/util/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..6c42a703ec1b4a5c7fc36811547a194375d4da6c --- /dev/null +++ b/util/metric.py @@ -0,0 +1,285 @@ +""" +Metrics for computing evalutation results +Modified from vanilla PANet code by Wang et al. +""" + +import numpy as np + +class Metric(object): + """ + Compute evaluation result + + Args: + max_label: + max label index in the data (0 denoting background) + n_scans: + number of test scans + """ + def __init__(self, max_label=20, n_scans=None): + self.labels = list(range(max_label + 1)) # all class labels + self.n_scans = 1 if n_scans is None else n_scans + + # list of list of array, each array save the TP/FP/FN statistic of a testing sample + self.tp_lst = [[] for _ in range(self.n_scans)] + self.fp_lst = [[] for _ in range(self.n_scans)] + self.fn_lst = [[] for _ in range(self.n_scans)] + self.slice_counter = [0 for _ in range(self.n_scans)] + + def reset(self): + """ + Reset accumulated evaluation. + """ + # assert self.n_scans == 1, 'Should not reset accumulated result when we are not doing one-time batch-wise validation' + del self.tp_lst, self.fp_lst, self.fn_lst + self.tp_lst = [[] for _ in range(self.n_scans)] + self.fp_lst = [[] for _ in range(self.n_scans)] + self.fn_lst = [[] for _ in range(self.n_scans)] + + def reset_scan(self, n_scan, labels:list=None): + """ + Reset accumulated evaluation for a specific scan. + """ + if labels is None: + labels = self.labels + for slice_idx in range(len(self.tp_lst[n_scan])): + for label in labels: + self.tp_lst[n_scan][slice_idx][label] = np.nan + self.fp_lst[n_scan][slice_idx][label] = np.nan + self.fn_lst[n_scan][slice_idx][label] = np.nan + + def record(self, pred, target, labels=None, n_scan=None): + """ + Record the evaluation result for each sample and each class label, including: + True Positive, False Positive, False Negative + + Args: + pred: + predicted mask array, expected shape is H x W + target: + target mask array, expected shape is H x W + labels: + only count specific label, used when knowing all possible labels in advance + """ + assert pred.shape == target.shape + + if self.n_scans == 1: + n_scan = 0 + + # array to save the TP/FP/FN statistic for each class (plus BG) + tp_arr = np.full(len(self.labels), np.nan) + fp_arr = np.full(len(self.labels), np.nan) + fn_arr = np.full(len(self.labels), np.nan) + if labels is None: + labels = self.labels + else: + labels = [0,] + labels + + for j, label in enumerate(labels): + # Get the location of the pixels that are predicted as class j + # idx = np.where(np.logical_and(pred == j, target != 255)) + # pred_idx_j = set(zip(idx[0].tolist(), idx[1].tolist())) + # # Get the location of the pixels that are class j in ground truth + # idx = np.where(target == j) + # target_idx_j = set(zip(idx[0].tolist(), idx[1].tolist())) + + # # this should not work: if target_idx_j: # if ground-truth contains this class + # # the author is adding posion to the code + # tp_arr[label] = len(set.intersection(pred_idx_j, target_idx_j)) + # fp_arr[label] = len(pred_idx_j - target_idx_j) + # fn_arr[label] = len(target_idx_j - pred_idx_j) + + # calc the tp, fp and fn normally and compare the 2 values + tp = ((pred == j).astype(int) * (target == j).astype(int)).sum() + fp = ((pred == j).astype(int) * (target != j).astype(int)).sum() + fn = ((pred != j).astype(int) * (target == j).astype(int)).sum() + + tp_arr[label] = tp + fp_arr[label] = fp + fn_arr[label] = fn + + # assert tp == tp_arr[label] + # assert fp == fp_arr[label] + # assert fn == fn_arr[label] + + self.tp_lst[n_scan].append(tp_arr) + self.fp_lst[n_scan].append(fp_arr) + self.fn_lst[n_scan].append(fn_arr) + self.slice_counter[n_scan] += 1 + + def get_mIoU(self, labels=None, n_scan=None): + """ + Compute mean IoU + + Args: + labels: + specify a subset of labels to compute mean IoU, default is using all classes + """ + if labels is None: + labels = self.labels + # Sum TP, FP, FN statistic of all samples + if n_scan is None: + tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + + # Compute mean IoU classwisely + # Average across n_scans, then average over classes + mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan]) + for _scan in range(self.n_scans)]) + mIoU = mIoU_class.mean(axis=1) + + return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0), + mIoU.mean(axis=0), mIoU.std(axis=0)) + else: + tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels) + fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels) + fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels) + + # Compute mean IoU classwisely and average over classes + mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum) + mIoU = mIoU_class.mean() + + return mIoU_class, mIoU + + def get_mDice(self, labels=None, n_scan=None, give_raw = False): + """ + Compute mean Dice score (in 3D scan level) + + Args: + labels: + specify a subset of labels to compute mean IoU, default is using all classes + """ + # NOTE: unverified + if labels is None: + labels = self.labels + # Sum TP, FP, FN statistic of all samples + if n_scan is None: + tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + + # Average across n_scans, then average over classes + mDice_class = np.vstack([ 2 * tp_sum[_scan] / ( 2 * tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan]) + for _scan in range(self.n_scans)]) + mDice = mDice_class.mean(axis=1) + print(f"mDice_class:\n {mDice_class}") + if not give_raw: + return (mDice_class.mean(axis=0), mDice_class.std(axis=0), + mDice.mean(axis=0), mDice.std(axis=0)) + else: + return (mDice_class.mean(axis=0), mDice_class.std(axis=0), + mDice.mean(axis=0), mDice.std(axis=0), mDice_class) + + else: + tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels) + fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels) + fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels) + + # Compute mean IoU classwisely and average over classes + mDice_class = 2 * tp_sum / ( 2 * tp_sum + fp_sum + fn_sum) + mDice = mDice_class.mean() + + if not give_raw: + return (mDice_class, mDice, mDice_class) + + return (mDice_class, mDice, mDice_class) + + def get_mPrecRecall(self, labels=None, n_scan=None, give_raw = False): + """ + Compute precision and recall + + Args: + labels: + specify a subset of labels to compute mean IoU, default is using all classes + """ + # NOTE: unverified + if labels is None: + labels = self.labels + # Sum TP, FP, FN statistic of all samples + if n_scan is None: + tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels) + for _scan in range(self.n_scans)] + + # Compute mean IoU classwisely + # Average across n_scans, then average over classes + mPrec_class = np.vstack([ tp_sum[_scan] / ( tp_sum[_scan] + fp_sum[_scan] ) + for _scan in range(self.n_scans)]) + + mRec_class = np.vstack([ tp_sum[_scan] / ( tp_sum[_scan] + fn_sum[_scan] ) + for _scan in range(self.n_scans)]) + + mPrec = mPrec_class.mean(axis=1) + mRec = mRec_class.mean(axis=1) + if not give_raw: + return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0)) + else: + return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0), mPrec_class, mRec_class) + + + else: + tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels) + fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels) + fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels) + + # Compute mean IoU classwisely and average over classes + mPrec_class = tp_sum / (tp_sum + fp_sum) + mPrec = mPrec_class.mean() + + mRec_class = tp_sum / (tp_sum + fn_sum) + mRec = mRec_class.mean() + + return mPrec_class, None, mPrec, None, mRec_class, None, mRec, None, mPrec_class, mRec_class + + def get_mIoU_binary(self, n_scan=None): + """ + Compute mean IoU for binary scenario + (sum all foreground classes as one class) + """ + # Sum TP, FP, FN statistic of all samples + if n_scan is None: + tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0) + for _scan in range(self.n_scans)] + fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0) + for _scan in range(self.n_scans)] + fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0) + for _scan in range(self.n_scans)] + + # Sum over all foreground classes + tp_sum = [np.c_[tp_sum[_scan][0], np.nansum(tp_sum[_scan][1:])] + for _scan in range(self.n_scans)] + fp_sum = [np.c_[fp_sum[_scan][0], np.nansum(fp_sum[_scan][1:])] + for _scan in range(self.n_scans)] + fn_sum = [np.c_[fn_sum[_scan][0], np.nansum(fn_sum[_scan][1:])] + for _scan in range(self.n_scans)] + + # Compute mean IoU classwisely and average across classes + mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan]) + for _scan in range(self.n_scans)]) + mIoU = mIoU_class.mean(axis=1) + + return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0), + mIoU.mean(axis=0), mIoU.std(axis=0)) + else: + tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0) + fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0) + fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0) + + # Sum over all foreground classes + tp_sum = np.c_[tp_sum[0], np.nansum(tp_sum[1:])] + fp_sum = np.c_[fp_sum[0], np.nansum(fp_sum[1:])] + fn_sum = np.c_[fn_sum[0], np.nansum(fn_sum[1:])] + + mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum) + mIoU = mIoU_class.mean() + + return mIoU_class, mIoU diff --git a/util/utils.py b/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a451de2307322656f626c683fbd9346f90eec8 --- /dev/null +++ b/util/utils.py @@ -0,0 +1,652 @@ +"""Util functions +Extended from original PANet code +TODO: move part of dataset configurations to data_utils +""" +import random +import torch +import numpy as np +import operator +import cv2 +import matplotlib.pyplot as plt +import kneed +import urllib +from tqdm.auto import tqdm +from sklearn.decomposition import PCA +import torchvision.transforms.functional as F + + +def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"): + num_labels, labels, stats, centroids = cca_output + # Create an output image with random colors for each component + output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8) + for label in range(1, num_labels): # Start from 1 to skip the background + mask = labels == label + output_image[mask] = np.random.randint(0, 255, size=3) + + # Plotting the original and the colored components image + plt.figure(figsize=(10, 5)) + plt.subplot(121), plt.imshow(original_image), plt.title('Original Image') + plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components') + if confidences is not None: + # Plot the axes color chart with the confidences, use the same colors as the connected components + plt.subplot(122) + scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet') + plt.colorbar(scatter) + + plt.savefig(title) + plt.close() + + +def reverse_tensor(tensor, original_h, original_w, degrees): + """ + tensor: tensor of shape (B, C, H, W) to be rotated + original_h: int - original height of the tensor (after it was rotated) + original_w: int - original width of the tensor (after it was rotated) + degrees: int or float - angle in degrees couterclockwise + """ + _, _, h, w = tensor.shape # this is the shape that we want to return to + if tensor.shape[-2:] != (original_h, original_w): + tensor = F.resize(tensor, (original_h, original_w), interpolation=F.InterpolationMode.BILINEAR, antialias=True) + # print("interpolating") + + rotated_tensor = F.rotate(tensor, degrees, expand=False) + # remove the black padding + h_remove = abs(h - original_h) // 2 + w_remove = abs(w - original_w) // 2 + if h_remove > 0 and w_remove > 0: + rotated_tensor = rotated_tensor[:, :, h_remove:-h_remove, w_remove:-w_remove] + + return rotated_tensor + + +def need_softmax(tensor, dim=1): + return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0)) + + +def rotate_tensor_no_crop(image_tensor, degrees): + """ + image_tensor: tensor of shape (B, C, H, W) + degrees: int or float - angle in degrees couterclockwise + returns: tensor of shape (B, C, H, W) rotated by degrees, + """ + if degrees == 0: + return image_tensor, image_tensor.shape[-2:] + + b, c, h, w = image_tensor.shape + rotated_tensor = F.rotate(image_tensor, degrees, expand=True) + + interpolation_mode = F.InterpolationMode.BILINEAR + if c == 1: + interpolation_mode = F.InterpolationMode.NEAREST + resized_tensor = F.resize(rotated_tensor, (h, w), interpolation=interpolation_mode, antialias=True) + + return resized_tensor, rotated_tensor.shape[-2:] + +def plot_dinov2_fts(img_fts, title="debug/img_fts.png"): + """ + Using PCA to reduce img_fts to 2D and plot it + Args: + img_fts: (B, C, H, W) + """ + if isinstance(img_fts, torch.Tensor): + img_fts = img_fts.cpu().detach().numpy() + + B, C, H, W = img_fts.shape + + img_fts_reshaped = img_fts.transpose(0, 2, 3, 1).reshape(-1, C) + + # Apply PCA to reduce dimensionality from C to 1 + pca = PCA(n_components=1) + img_fts_pca = pca.fit_transform(img_fts_reshaped) + + # Reshape back to (B, 1, H, W) + img_fts_reduced = img_fts_pca.reshape(B, H, W, 1).transpose(0, 3, 1, 2) + + # Plot the B images + if B == 1: + fig, ax = plt.subplots(figsize=(5, 5)) + ax.imshow(img_fts_reduced[0, 0]) + else: + fig, axes = plt.subplots(1, B, figsize=(B*5, 5)) + for i, ax in enumerate(axes.flat): + ax.imshow(img_fts_reduced[i, 0]) + # ax.axis('off') + + plt.tight_layout() + plt.savefig(title) + plt.close(fig) + + +def move_to_device(dict_obj, device='cuda'): + for key in dict_obj: + value = dict_obj[key] + if isinstance(value, torch.Tensor): + dict_obj[key] = value.to(device) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, torch.Tensor): + dict_obj[key][i] = item.to(device) + + +def validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, query_images, _config, q_part=0): + model.eval() + + sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][q_part]]] # way(1) x shot x [B(1) x C x H x W] + sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][q_part]]] + sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][q_part]]] + + with torch.no_grad(): + query_pred_logits, _, _, assign_mats, _, _ = model( sup_img_part , sup_fgm_part, sup_bgm_part, query_images, isval = True, val_wsize = _config["val_wsize"] ) + + query_pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu().detach()) + + if _config['do_cca']: + query_pred = cca(query_pred, query_pred_logits) + + if _config["debug"]: + # plot the support images, support fg mask, query image, query pred before cca and query pred after cca + fig, ax = plt.subplots(3, 2, figsize=(15, 10)) + ax[0,0].imshow(support_images[0][q_part][0,0].cpu().numpy(), cmap='gray') + ax[0,1].imshow(support_fg_mask[0][q_part][0].cpu().numpy(), cmap='gray') + ax[1,0].imshow(query_images[0][0][0].cpu().numpy(), cmap='gray') + ax[1,1].imshow(query_pred_logits.argmax(dim=1)[0].cpu().detach().numpy(), cmap='gray') + ax[2,0].imshow(query_pred, cmap='gray') + ax[2,1].imshow(query_pred_logits.argmax(dim=1)[0].cpu().detach().numpy(), cmap='gray') + # remove all ticks + for axi in ax.flat: + axi.set_xticks([]) + axi.set_yticks([]) + fig.savefig("debug/cca_before_after.png") + plt.close(fig) + + model.train() + return query_pred, query_pred_logits + + +def validation_on_scans(model, curr_lb, support_images, support_fg_mask, support_bg_mask, testloader, te_parent, te_dataset, _config, sup_img_indx=1, save_pred_buffer=None): + if save_pred_buffer is None: + save_pred_buffer = {} + lb_buffer = {} + conf_buffer = {} + # sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][sup_img_indx]]] # way(1) x shot x [B(1) x C x H x W] + # sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][sup_img_indx]]] + # sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][sup_img_indx]]] + for scan_idx, sample_batched in enumerate(testloader): + print(f"Processing scan: {scan_idx + 1} / {len(testloader)}") + _scan_id = sample_batched["scan_id"][0] + if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include that to query + print(f"Skipping support scan: {_scan_id}") # TODO delete + continue + + outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"] + outsize = (_config['input_size'][0], _config['input_size'][1], outsize[0]) # original image read by itk: Z, H, W, in prediction we use H, W, Z + _pred = np.zeros( outsize ) + _pred.fill(np.nan) + conf_buffer[_scan_id] = [] + + query_images = sample_batched['image'].cuda() + z_min = sample_batched['z_min'][0] + z_max = sample_batched['z_max'][0] + # create an index list that starts with s_idx goes down to 0, then concat the indices from s_idx + 1 to the end + # this is to make sure that the most similiar slice is the first one to be processed + indices = list(range(len(query_images[0]))) + qpart = sup_img_indx + for idx, i in enumerate(tqdm(indices)): + if _config["use_3_slices"]: + # change the query to 3 slices (-1, 0, 1) + if i == 0: + prev_q = torch.zeros_like(query_images[0, i]).unsqueeze(0) + else: + prev_q = query_images[0, i - 1].unsqueeze(0) + if i == len(query_images[0]) - 1: + next_q = torch.zeros_like(query_images[0, i]).unsqueeze(0) + else: + next_q = query_images[0, i + 1].unsqueeze(0) + + query = torch.cat([prev_q, query_images[0, i].unsqueeze(0), next_q], dim=1) + + else: + query = query_images[0, i].unsqueeze(0) + + + query_pred, query_pred_logits = validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, [query], _config, q_part=qpart) + query_conf = get_confidence_from_logits(query_pred_logits, query_pred) + conf_buffer[_scan_id].append(query_conf) + _pred[..., i] = query_pred.copy() + + if _config['dataset'] != 'C0': + lb_buffer[_scan_id] = _pred.transpose(2,0,1) + else: + lb_buffer[_scan_id] = _pred + save_pred_buffer[str(curr_lb)] = lb_buffer + + return save_pred_buffer, conf_buffer + + + +def validation(model, curr_lb, testloader, te_parent, te_dataset, _config, support_images, support_fg_mask, support_bg_mask, mar_val_metric_node=None, save_pred_buffer=None, do_validation=False, get_confidence=False): + model.eval() + with torch.no_grad(): + curr_scan_count = -1 # counting for current scan + _lb_buffer = {} # indexed by scan + _conf_buffer = {} # indexed by scan + _has_label_buffer = {} # indexed by scan + last_qpart = 0 # used as indicator for adding result to buffer + + for idx, sample_batched in enumerate(tqdm(testloader)): + _scan_id = sample_batched["scan_id"][0] # we assume batch size for query is 1 + if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include that to query + continue + if sample_batched["is_start"]: + ii = 0 + curr_scan_count += 1 + if do_validation: + if curr_scan_count > 0: + break + print(f"Processing scan {curr_scan_count + 1} / {len(te_dataset.dataset.pid_curr_load)}") + _scan_id = sample_batched["scan_id"][0] + outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"] + outsize = (te_dataset.dataset.image_size, te_dataset.dataset.image_size, outsize[0]) # original image read by itk: Z, H, W, in prediction we use H, W, Z + _pred = np.zeros( outsize ) + _pred.fill(np.nan) + _conf_buffer[_scan_id] = [] + _has_label_buffer[_scan_id] = [] + + q_part = sample_batched["part_assign"] # the chunck of query, for assignment with support + query_images = [sample_batched['image'].cuda()] + query_labels = torch.cat([ sample_batched['label'].cuda()], dim=0) + # if not 1 in query_labels: + # continue + # [way, [part, [shot x C x H x W]]] -> + query_pred, query_pred_logits = validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, query_images, _config, q_part=q_part) + _pred[..., ii] = query_pred.copy() + if 1 in query_labels: + _has_label_buffer[_scan_id].append(True) + else: + _has_label_buffer[_scan_id].append(False) + + if get_confidence: + # calc condfidence from logits and log it in the _conf_buffer + query_conf = get_confidence_from_logits(query_pred_logits, query_pred) + _conf_buffer[_scan_id].append(query_conf) + + if mar_val_metric_node is not None and ((sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin'])): + mar_val_metric_node.record(query_pred, np.array(query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count) + else: + pass + + ii += 1 + # now check data format + if sample_batched["is_end"]: + if _config['dataset'] != 'C0': + _lb_buffer[_scan_id] = _pred.transpose(2,0,1) # H, W, Z -> to Z H W + else: + _lb_buffer[_scan_id] = _pred + + save_pred_buffer[str(curr_lb)] = _lb_buffer + + model.train() + + return save_pred_buffer, _conf_buffer, _has_label_buffer + + +def load_config_from_url(url: str) -> str: + with urllib.request.urlopen(url) as f: + return f.read().decode() + + +def save_pred_gt_fig(query_images, query_pred, query_labels, support_images=None, support_labels=None, path="debug/gt_vs_pred.png"): + fig = plt.figure(figsize=(10, 5 if support_images is None else 10)) + ax1 = fig.add_subplot(2 if support_images is not None else 1, 2, 1) + ax1.imshow(query_images[0][0, 1].cpu().numpy()) + ax1.imshow(query_labels[0].cpu().numpy(), alpha=0.5) + ax1.set_title("Ground Truth") + ax2 = fig.add_subplot(2 if support_images is not None else 1, 2, 2) + ax2.imshow(query_images[0][0, 1].cpu().numpy()) + ax2.imshow(query_pred, alpha=0.5) + ax2.set_title("Prediction") + if support_images is not None: + ax3 = fig.add_subplot(2, 2, 3) + ax3.imshow(support_images[0][0, 1].cpu().numpy()) + ax3.imshow(support_labels[0].cpu().numpy(), alpha=0.5) + ax3.set_title("Support") + plt.savefig(path) + plt.close('all') + + +def plot_heatmap_of_probs(probs, image, path=None): + # normalize image values to be between 0 and 1, assume image doesnt have a specific range + image = (image - image.min()) / (image.max() - image.min()) + rgb_image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + # Create a 3D figure + fig = plt.figure() + ax = fig.add_subplot(111) + ax.imshow(rgb_image) + ax.imshow(probs, alpha=0.5) + if path is not None: + fig.savefig(path) + else: + plt.show() + plt.close(fig) + + +def plot_3d_bar_probabilities(probabilities, labels, image, path=None): + # Create a 3D figure + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + + # Create a meshgrid of the x and y coordinates + x, y = np.meshgrid(np.arange(probabilities.shape[1]), np.arange(probabilities.shape[0])) + + # Flatten the probabilities and labels data and convert them to 1D arrays + z = probabilities.flatten() + c = np.where(labels.flatten() == 1, 'g', 'r') + + # normaliize image values to be between 0 and 1, assume image doesnt have a specific range + image = (image - image.min()) / (image.max() - image.min()) + rgb_image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + # ax.imshow(rgb_image, extent=[0, probabilities.shape[1], 0, probabilities.shape[0]], alpha=0.5) + + # Create the 3D bar plot + ax.plot_surface(x, y, np.zeros_like(x), facecolors=rgb_image) + ax.bar3d(x.ravel(), y.ravel(), np.zeros_like(z), 1, 1, z, color=c, alpha=0.3) + + # Set the axis labels + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Probability') + + # Show the plot + if path is not None: + fig.savefig(path) + else: + plt.show() + plt.close(fig) + +# def plot_3d_bar_probabilities(probabilities, labels, path=None): +# # Create a 3D figure +# fig = plt.figure() +# ax = fig.add_subplot(111, projection='3d') + +# # Create a meshgrid of the x and y coordinates +# x, y = np.meshgrid(np.arange(probabilities.shape[1]), np.arange(probabilities.shape[0])) + +# # Flatten the probabilities and labels data and convert them to 1D arrays +# z = probabilities.flatten() +# c = np.where(labels.flatten() == 1, 'g', 'r') + +# # Create the 3D bar plot +# ax.bar3d(x.ravel(), y.ravel(), np.zeros_like(z), 1, 1, z, color=c) + +# # Set the axis labels +# ax.set_xlabel('X') +# ax.set_ylabel('Y') +# ax.set_zlabel('Probability') + +# # Show the plot +# if path is not None: +# fig.savefig(path) +# else: +# plt.show() +# plt.close(fig) + + +# def sliding_window_confidence_segmentation(query_pred_conf:np.array, window_size=3, threshold=0.5): +# """ +# query_pred_conf: np.array, shape (B, H, W) +# """ +# # slice window across the query_pred_conf, if the window has a mean confidence > 0.5, the center pixel is 1, otherwise 0 + +# pred = np.zeros_like(query_pred_conf) +# # slice the window +# for i in range(query_pred_conf.shape[-1] - window_size + 1): +# for j in range(query_pred_conf.shape[-2] - window_size + 1): +# window = query_pred_conf[:, i:i+window_size, j:j+window_size] +# if np.mean(window) > threshold: +# pred[:, i+window_size//2, j+window_size//2] = 1 + +# return pred + + +def sliding_window_confidence_segmentation(query_pred_conf: np.array, window_size=3, threshold=0.5): + """ + query_pred_conf: np.array, shape (B, H, W) + """ + B, H, W = query_pred_conf.shape + pad = window_size // 2 + padded_conf = np.pad(query_pred_conf, ((0, 0), (pad, pad), (pad, pad)), mode='constant') + + # Calculate the mean in sliding windows + window_view = np.lib.stride_tricks.sliding_window_view(padded_conf, (B, window_size, window_size)) + mean_values = np.mean(window_view, axis=(-1, -2)) + + pred = (mean_values > threshold).astype(int) + + return pred[..., 0] + + + +def get_confidence_from_logits(query_pred_logits: torch.Tensor): + query_probs = query_pred_logits.softmax(1)[:,1].flatten(1) + query_pred = query_probs.clone() + query_pred[query_probs < 0.5] = 0 + query_pred[query_probs >= 0.5] = 1 + return ((query_probs * query_pred).sum() / (query_pred.sum() + 1e-6)).item() + +def choose_threshold_kneedle(p): + ''' + p - probabilities of prediction + ''' + # use kneed to choose the threshold + # create pdf from x + n_bins = min(100, len(p)) + hist, bin_edges = np.histogram(p, bins=n_bins) + pdf = hist / hist.sum() + cdf = np.cumsum(pdf) + + x = np.linspace(0, 1, n_bins) + y = cdf + # plot x, y in a fig and save the fig + plt.figure() + plt.plot(x, y) + plt.savefig(f'debug/cdf.png') + plt.figure() + plt.plot(x, pdf) + plt.savefig(f'debug/pdf.png') + plt.close('all') + kneedle = kneed.KneeLocator(x, y, curve='convex', direction='increasing') + # get the value at the knee from the bin_edges + threshold = bin_edges[int(kneedle.knee * n_bins)] + + return threshold + + +def plot_cca_output(cca_output): + for j in range(cca_output[0]): + if j == 0: + continue + plt.figure() + plt.imshow(cca_output[1] == j) + plt.savefig(f'debug/cca_{j}.png') + plt.close('all') + + +def get_connected_components(query_pred_original, query_pred_logits, return_conf=False): + """ + get all connected components + """ + cca_output = cv2.connectedComponentsWithStats(query_pred_original.astype(np.uint8), connectivity=8) # TODO try 8 + + # plot_cca_output(cca_output) + + if return_conf: + # calc confidence for each connected component + cca_conf = {} # conf by id + query_probs = query_pred_logits.softmax(1)[:,1].cpu().detach().numpy() + for j in range(cca_output[0]): + if j == 0: + cca_conf[0] = 0 # background + continue + cca_conf[j] = ((query_probs.flatten() * (cca_output[1] == j).flatten()).sum() / ((query_pred_original.flatten().sum() + 1e-6))) # take into account the area of the connected component + + return cca_output, cca_conf + + return cca_output, None + +def cca(query_pred_original, query_pred_logits, return_conf=False, return_cc=False): + ''' + Performs connected component analysis on the query_pred and returns the most confident connected component + ''' + # cca_output = cv2.connectedComponentsWithStats(query_pred_original.astype(np.uint8), connectivity=8) # TODO try 8 + # # calc confidence for each connected component + # cca_conf = [] + # for j in range(cca_output[0]): + # if j == 0: + # cca_conf.append(0) # background + # continue + # cca_conf.append((query_pred_logits.softmax(1)[:,1].flatten(1).cpu().detach().numpy() * (cca_output[1] == j).flatten()).sum() / ((cca_output[1] == j).flatten().sum() + 1e-6) * ((cca_output[1] == j).flatten().sum() / (query_pred_original.flatten().sum() + 1e-6))) # take into account the area of the connected component + cca_output, cca_conf = get_connected_components(query_pred_original, query_pred_logits, return_conf=True) + + # find the most confident connected component, find max conf and its key + max_conf = cca_conf[0] + for k,v in cca_conf.items(): + if v > max_conf: + max_conf = v + max_key = k + + if max_conf == 0: + # no connected component found, use zeros + query_pred = np.zeros_like(query_pred_original) + else: + # zero out all other connected components + new_cca_output = list(cca_output) + new_cca_output[0] = 2 # bg + fg + new_cca_output[1] = np.where(cca_output[1] != max_key, 0, 1) # binarize the max_key + new_cca_output[2] = cca_output[2][[0, max_key]] + new_cca_output[3] = cca_output[3][[0, max_key]] + cca_output = tuple(new_cca_output) + + query_pred = (cca_output[1] == 1).astype(np.uint8) + # convert to binary mask + query_pred = (query_pred > 0).astype(np.uint8) + + if return_cc: + return cca_output + + query_pred_original = query_pred_original * query_pred + + if return_conf: + return query_pred_original, max_conf + + return query_pred_original + +def set_seed(seed): + """ + Set the random seed + """ + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +CLASS_LABELS = { + 'SABS': { + 'pa_all': set( [1,2,3,6] ), + 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneis are testing + 1: set( [2,3] ), # lower_abdomen + }, + 'C0': { + 'pa_all': set(range(1, 4)), + 0: set([2,3]), + 1: set([1,3]), + 2: set([1,2]), + }, + 'CHAOST2': { + 'pa_all': set(range(1, 5)), + 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes + 1: set([2, 3]), # lower_abdomen + }, +} + +def get_bbox(fg_mask, inst_mask): + """ + Get the ground truth bounding boxes + """ + + fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device) + bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device) + + inst_mask[fg_mask == 0] = 0 + area = torch.bincount(inst_mask.view(-1)) + cls_id = area[1:].argmax() + 1 + cls_ids = np.unique(inst_mask)[1:] + + mask_idx = np.where(inst_mask[0] == cls_id) + y_min = mask_idx[0].min() + y_max = mask_idx[0].max() + x_min = mask_idx[1].min() + x_max = mask_idx[1].max() + fg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 1 + + for i in cls_ids: + mask_idx = np.where(inst_mask[0] == i) + y_min = max(mask_idx[0].min(), 0) + y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1) + x_min = max(mask_idx[1].min(), 0) + x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1) + bg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 0 + return fg_bbox, bg_bbox + +def t2n(img_t): + """ + torch to numpy regardless of whether tensor is on gpu or memory + """ + if img_t.is_cuda: + return img_t.data.cpu().numpy() + else: + return img_t.data.numpy() + +def to01(x_np): + """ + normalize a numpy to 0-1 for visualize + """ + return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-5) + +def compose_wt_simple(is_wce, data_name): + """ + Weights for cross-entropy loss + """ + # if is_wce: + # if data_name in ['SABS', 'SABS_Superpix', 'SABS_448', 'SABS_Superpix_448', 'SABS_672', 'SABS_Superpix_672','C0', 'C0_Superpix', 'CHAOST2', 'CHAOST2_Superpix', 'CHAOST2_672', 'CHAOST2_Superpix_672', 'LITS17', 'LITS17_Superpix']: + # return torch.FloatTensor([0.05, 1.0]).cuda() + # else: + # raise NotImplementedError + # else: + # return torch.FloatTensor([1.0, 1.0]).cuda() + return torch.FloatTensor([0.05, 1.0]).cuda() + + +class CircularList(list): + """ + Helper for spliting training and validation scans + Originally: https://stackoverflow.com/questions/8951020/pythonic-circular-list/8951224 + """ + def __getitem__(self, x): + if isinstance(x, slice): + return [self[x] for x in self._rangeify(x)] + + index = operator.index(x) + try: + return super().__getitem__(index % len(self)) + except ZeroDivisionError: + raise IndexError('list index out of range') + + def _rangeify(self, slice): + start, stop, step = slice.start, slice.stop, slice.step + if start is None: + start = 0 + if stop is None: + stop = len(self) + if step is None: + step = 1 + return range(start, stop, step) + diff --git a/validation.py b/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..ce406e52e10770869613c1ed4fa8ae6a3608f9f5 --- /dev/null +++ b/validation.py @@ -0,0 +1,367 @@ +""" +Validation script +""" +import os +import shutil +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import MultiStepLR +import torch.backends.cudnn as cudnn +import numpy as np +import matplotlib.pyplot as plt +from models.grid_proto_fewshot import FewShotSeg +from dataloaders.dev_customized_med import med_fewshot_val + +from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset +from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset +from dataloaders.dataset_utils import DATASET_INFO, get_normalize_op +from dataloaders.niftiio import convert_to_sitk +import dataloaders.augutils as myaug + +from util.metric import Metric +from util.consts import IMG_SIZE +from util.utils import cca, sliding_window_confidence_segmentation, plot_3d_bar_probabilities, save_pred_gt_fig, plot_heatmap_of_probs +from config_ssl_upload import ex + +from tqdm import tqdm +import SimpleITK as sitk +from torchvision.utils import make_grid +from tqdm.auto import tqdm + +from util.utils import set_seed, t2n, to01, compose_wt_simple +# config pre-trained model caching path +os.environ['TORCH_HOME'] = "./pretrained_model" + + +def test_time_training(_config, model, image, prediction): + model.train() + data_name = _config['dataset'] + my_weight = compose_wt_simple(_config["use_wce"], data_name) + criterion = nn.CrossEntropyLoss( + ignore_index=_config['ignore_label'], weight=my_weight) + if _config['optim_type'] == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), **_config['optim']) + elif _config['optim_type'] == 'adam': + optimizer = torch.optim.AdamW( + model.parameters(), lr=_config['lr'], eps=1e-5) + else: + raise NotImplementedError + optimizer.zero_grad() + scheduler = MultiStepLR( + optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma']) + + tr_transforms = myaug.transform_with_label( + {'aug': myaug.get_aug(_config['which_aug'], _config['input_size'][0])}) + + comp = np.concatenate([image.transpose(1, 2, 0), prediction[None,...].transpose(1,2,0)], axis= -1) + print("Test Time Training...") + pbar = tqdm(range(_config['n_steps'])) + for idx in pbar: + query_image, query_label = tr_transforms(comp, c_img=image.shape[0], c_label=1, nclass=2, use_onehot=False) + support_image, support_label = tr_transforms(comp, c_img=image.shape[0], c_label=1, nclass=2, use_onehot=False) + query_label = torch.from_numpy(query_label.transpose(2,1,0)).cuda().long() + + query_images = [torch.from_numpy(query_image.transpose(2, 1, 0)).unsqueeze(0).cuda().float().requires_grad_(True)] + support_fg_mask = [[torch.from_numpy(support_label.transpose(2, 1, 0)).cuda().float().requires_grad_(True)]] + support_bg_mask = [[torch.from_numpy(1 - support_label.transpose(2, 1, 0)).cuda().float().requires_grad_(True)]] + support_images = [[torch.from_numpy(support_image.transpose(2, 1, 0)).unsqueeze(0).cuda().float().requires_grad_(True)]] + + # fig, ax = plt.subplots(1, 2) + # ax[0].imshow(query_images[0][0,0].cpu().numpy()) + # ax[1].imshow(support_image[...,0]) + # ax[1].imshow(support_label[...,0], alpha=0.5) + # fig.savefig("debug/query_support_ttt.png") + out = model(support_images, support_fg_mask, support_bg_mask, query_images, isval=False, val_wsize=None) + query_pred, align_loss, _, _, _, _, _ = out + # fig, ax = plt.subplots(1, 2) + # pred = np.array(query_pred.argmax(dim=1)[0].cpu()) + # ax[0].imshow(query_images[0][0,0].cpu().numpy()) + # ax[0].imshow(pred, alpha=0.5) + # ax[1].imshow(support_image[...,0]) + # ax[1].imshow(support_label[...,0], alpha=0.5) + # fig.savefig("debug/ttt.png") + loss = 0.0 + loss += criterion(query_pred.float(), query_label.long()) + loss += align_loss + loss.backward() + + if (idx + 1) % _config['grad_accumulation_steps'] == 0: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + pbar.set_postfix(loss=f"{loss.item():.4f}") + model.eval() + return model + + +@ex.automain +def main(_run, _config, _log): + if _run.observers: + os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True) + for source_file, _ in _run.experiment_info['sources']: + os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), + exist_ok=True) + _run.observers[0].save_file(source_file, f'source/{source_file}') + shutil.rmtree(f'{_run.observers[0].basedir}/_sources') + + torch.cuda.set_device(device=_config['gpu_id']) + torch.set_num_threads(1) + + _log.info(f'###### Reload model {_config["reload_model_path"]} ######') + model = FewShotSeg(image_size=_config['input_size'][0], + pretrained_path=_config['reload_model_path'], cfg=_config['model']) + + model = model.cuda() + model.eval() + + _log.info('###### Load data ######') + # Training set + data_name = _config['dataset'] + if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672': + baseset_name = 'SABS' + max_label = 13 + elif data_name == 'C0_Superpix': + raise NotImplementedError + baseset_name = 'C0' + max_label = 3 + elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672': + baseset_name = 'CHAOST2' + max_label = 4 + elif 'lits' in data_name.lower(): + baseset_name = 'LITS17' + max_label = 4 + else: + raise ValueError(f'Dataset: {data_name} not found') + + test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - \ + DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]] + + + _log.info( + f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######') + _log.info( + f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######') + + if baseset_name == 'SABS': + tr_parent = SuperpixelDataset( # base dataset + which_dataset=baseset_name, + base_dir=_config['path'][data_name]['data_dir'], + idx_split=_config['eval_fold'], + mode='val', # 'train', + # dummy entry for superpixel dataset + min_fg=str(_config["min_fg_data"]), + image_size=_config['input_size'][0], + transforms=None, + nsup=_config['task']['n_shots'], + scan_per_load=_config['scan_per_load'], + exclude_list=_config["exclude_cls_list"], + superpix_scale=_config["superpix_scale"], + fix_length=_config["max_iters_per_load"] if (data_name == 'C0_Superpix') or ( + data_name == 'CHAOST2_Superpix') else None, + use_clahe=_config['use_clahe'], + norm_mean=0.18792 * 256 if baseset_name == 'LITS17' else None, + norm_std=0.25886 * 256 if baseset_name == 'LITS17' else None + ) + norm_func = tr_parent.norm_func + else: + norm_func = get_normalize_op(modality='MR', fids=None) + + te_dataset, te_parent = med_fewshot_val( + dataset_name=baseset_name, + base_dir=_config['path'][data_name]['data_dir'], + idx_split=_config['eval_fold'], + scan_per_load=_config['scan_per_load'], + act_labels=test_labels, + npart=_config['task']['npart'], + nsup=_config['task']['n_shots'], + extern_normalize_func=norm_func, + image_size=_config["input_size"][0], + use_clahe=_config['use_clahe'], + use_3_slices=_config["use_3_slices"] + ) + + # dataloaders + testloader = DataLoader( + te_dataset, + batch_size=1, + shuffle=False, + num_workers=1, + pin_memory=False, + drop_last=False + ) + + _log.info('###### Set validation nodes ######') + mar_val_metric_node = Metric(max_label=max_label, n_scans=len( + te_dataset.dataset.pid_curr_load) - _config['task']['n_shots']) + + _log.info('###### Starting validation ######') + mar_val_metric_node.reset() + if _config["sliding_window_confidence_segmentation"]: + print("Using sliding window confidence segmentation") # TODO delete this + + save_pred_buffer = {} # indexed by class + + for curr_lb in test_labels: + te_dataset.set_curr_cls(curr_lb) + support_batched = te_parent.get_support(curr_class=curr_lb, class_idx=[ + curr_lb], scan_idx=_config["support_idx"], npart=_config['task']['npart']) + + # way(1 for now) x part x shot x 3 x H x W] # + support_images = [[shot.cuda() for shot in way] + for way in support_batched['support_images']] # way x part x [shot x C x H x W] + suffix = 'mask' + support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way] + for way in support_batched['support_mask']] + support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way] + for way in support_batched['support_mask']] + + curr_scan_count = -1 # counting for current scan + _lb_buffer = {} # indexed by scan + _lb_vis_buffer = {} + + last_qpart = 0 # used as indicator for adding result to buffer + + for idx, sample_batched in enumerate(tqdm(testloader)): + # we assume batch size for query is 1 + _scan_id = sample_batched["scan_id"][0] + if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include that to query + continue + if sample_batched["is_start"]: + ii = 0 + curr_scan_count += 1 + print( + f"Processing scan {curr_scan_count + 1} / {len(te_dataset.dataset.pid_curr_load)}") + _scan_id = sample_batched["scan_id"][0] + outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"] + # original image read by itk: Z, H, W, in prediction we use H, W, Z + outsize = (_config['input_size'][0], + _config['input_size'][1], outsize[0]) + _pred = np.zeros(outsize) + _pred.fill(np.nan) + # assign proto shows in the query image which proto is assigned to each pixel, proto_grid is the ids of the prototypes in the support image used, support_images are the 3 support images, support_img_parts are the parts of the support images used for each query image + _vis = {'assigned_proto': [None] * _pred.shape[-1], 'proto_grid': [None] * _pred.shape[-1], + 'support_images': support_images, 'support_img_parts': [None] * _pred.shape[-1]} + + # the chunck of query, for assignment with support + q_part = sample_batched["part_assign"] + query_images = [sample_batched['image'].cuda()] + query_labels = torch.cat( + [sample_batched['label'].cuda()], dim=0) + if 1 not in query_labels and not sample_batched["is_end"] and _config["skip_no_organ_slices"]: + ii += 1 + continue + # [way, [part, [shot x C x H x W]]] -> + # way(1) x shot x [B(1) x C x H x W] + sup_img_part = [[shot_tensor.unsqueeze( + 0) for shot_tensor in support_images[0][q_part]]] + sup_fgm_part = [[shot_tensor.unsqueeze( + 0) for shot_tensor in support_fg_mask[0][q_part]]] + sup_bgm_part = [[shot_tensor.unsqueeze( + 0) for shot_tensor in support_bg_mask[0][q_part]]] + + # query_pred_logits, _, _, assign_mats, proto_grid, _, _ = model( + # sup_img_part, sup_fgm_part, sup_bgm_part, query_images, isval=True, val_wsize=_config["val_wsize"], show_viz=True) + with torch.no_grad(): + out = model(sup_img_part, sup_fgm_part, sup_bgm_part, + query_images, isval=True, val_wsize=_config["val_wsize"]) + query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out + pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu()) + + if _config["ttt"]: + state_dict = model.state_dict() + model = test_time_training(_config, model, sample_batched['image'].numpy()[0], pred) + out = model(sup_img_part, sup_fgm_part, sup_bgm_part, + query_images, isval=True, val_wsize=_config["val_wsize"]) + query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out + pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu()) + if _config["reset_after_slice"]: + model.load_state_dict(state_dict) + + query_pred = query_pred_logits.argmax(dim=1).cpu() + query_pred = F.interpolate(query_pred.unsqueeze( + 0).float(), size=query_labels.shape[-2:], mode='nearest').squeeze(0).long().numpy()[0] + + if _config["debug"]: + save_pred_gt_fig(query_images, query_pred, query_labels, sup_img_part[0], sup_fgm_part[0][0], + f'debug/preds/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred.png') + + if _config['do_cca']: + query_pred = cca(query_pred, query_pred_logits) + if _config["debug"]: + save_pred_gt_fig(query_images, query_pred, query_labels, + f'debug/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred_after_cca.png') + + _pred[..., ii] = query_pred.copy() + # _vis['assigned_proto'][ii] = assign_mats + # _vis['proto_grid'][ii] = proto_grid.cpu() + # proto_ids = torch.unique(proto_grid) + # _vis['support_img_parts'][ii] = q_part + + if (sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin']) and not sample_batched["is_end"]: + mar_val_metric_node.record(query_pred, np.array( + query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count) + else: + pass + + ii += 1 + # now check data format + if sample_batched["is_end"]: + if _config['dataset'] != 'C0': + _lb_buffer[_scan_id] = _pred.transpose( + 2, 0, 1) # H, W, Z -> to Z H W + else: + _lb_buffer[_scan_id] = _pred + # _lb_vis_buffer[_scan_id] = _vis + + save_pred_buffer[str(curr_lb)] = _lb_buffer + + # save results + for curr_lb, _preds in save_pred_buffer.items(): + for _scan_id, _pred in _preds.items(): + _pred *= float(curr_lb) + itk_pred = convert_to_sitk( + _pred, te_dataset.dataset.info_by_scan[_scan_id]) + fid = os.path.join( + f'{_run.observers[0].dir}/interm_preds', f'scan_{_scan_id}_label_{curr_lb}.nii.gz') + sitk.WriteImage(itk_pred, fid, True) + _log.info(f'###### {fid} has been saved ######') + + + # compute dice scores by scan + m_classDice, _, m_meanDice, _, m_rawDice = mar_val_metric_node.get_mDice( + labels=sorted(test_labels), n_scan=None, give_raw=True) + + m_classPrec, _, m_meanPrec, _, m_classRec, _, m_meanRec, _, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall( + labels=sorted(test_labels), n_scan=None, give_raw=True) + + mar_val_metric_node.reset() # reset this calculation node + + # write validation result to log file + _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist()) + _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist()) + _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist()) + + _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist()) + _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist()) + _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist()) + + _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist()) + _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec.tolist()) + _run.log_scalar('mar_val_al_batches_rawRec', m_rawRec.tolist()) + + _log.info(f'mar_val batches classDice: {m_classDice}') + _log.info(f'mar_val batches meanDice: {m_meanDice}') + + _log.info(f'mar_val batches classPrec: {m_classPrec}') + _log.info(f'mar_val batches meanPrec: {m_meanPrec}') + + _log.info(f'mar_val batches classRec: {m_classRec}') + _log.info(f'mar_val batches meanRec: {m_meanRec}') + + print("============ ============") + + _log.info(f'End of validation') + return 1 diff --git a/validation_protosam.py b/validation_protosam.py new file mode 100644 index 0000000000000000000000000000000000000000..9697662712741018d2fa8c5cabb84cabebac6e2c --- /dev/null +++ b/validation_protosam.py @@ -0,0 +1,586 @@ +""" +Validation script +""" +import math +import os +import pandas as pd +import csv +import shutil +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.transforms as transforms +import torchvision.transforms.functional as F +from torch.utils.data import DataLoader +import torch.backends.cudnn as cudnn +import numpy as np +import time +import matplotlib.pyplot as plt +from models.ProtoSAM import ProtoSAM, ALPNetWrapper, SamWrapperWrapper, InputFactory, ModelWrapper, TYPE_ALPNET, TYPE_SAM +from models.ProtoMedSAM import ProtoMedSAM +from models.grid_proto_fewshot import FewShotSeg +from models.segment_anything.utils.transforms import ResizeLongestSide +from models.SamWrapper import SamWrapper +# from dataloaders.PolypDataset import get_polyp_dataset, get_vps_easy_unseen_dataset, get_vps_hard_unseen_dataset, PolypDataset, KVASIR, CVC300, COLON_DB, ETIS_DB, CLINIC_DB +from dataloaders.PolypDataset import get_polyp_dataset, PolypDataset +from dataloaders.PolypTransforms import get_polyp_transform +from dataloaders.SimpleDataset import SimpleDataset +from dataloaders.ManualAnnoDatasetv2 import get_nii_dataset +from dataloaders.common import ValidationDataset +from config_ssl_upload import ex + +import tqdm +from tqdm.auto import tqdm +import cv2 +from collections import defaultdict + +# config pre-trained model caching path +os.environ['TORCH_HOME'] = "./pretrained_model" + +# Supported Datasets +CHAOS = "chaos" +SABS = "sabs" +POLYPS = "polyps" + +ALP_DS = [CHAOS, SABS] + +ROT_DEG = 0 + +def get_bounding_box(segmentation_map): + """Generate bounding box from a segmentation map. one bounding box to include the extreme points of the segmentation map.""" + if isinstance(segmentation_map, torch.Tensor): + segmentation_map = segmentation_map.cpu().numpy() + + bbox = cv2.boundingRect(segmentation_map.astype(np.uint8)) + # plot bounding boxes for each contours + # plt.figure() + # x, y, w, h = bbox + # plt.imshow(segmentation_map) + # plt.gca().add_patch(plt.Rectangle((x, y), w, h, fill=False, edgecolor='r', linewidth=2)) + # plt.savefig("debug/bounding_boxes.png") + + return bbox + +def calc_iou(boxA, boxB): + """ + boxA: [x, y, w, h] + """ + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2]) + yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3]) + + interArea = max(0, xB - xA) * max(0, yB - yA) + boxAArea = boxA[2] * boxA[3] + boxBArea = boxB[2] * boxB[3] + + iou = interArea / float(boxAArea + boxBArea - interArea) + return iou + + +def eval_detection(pred_list): + """ + pred_list: list of dictionaries with keys 'pred_bbox', 'gt_bbox' and score (prediction confidence score). + compute AP50, AP75, AP50:95:10 + """ + iou_thresholds = np.round(np.arange(0.5, 1.0, 0.05), 2) + ap_dict = {iou: [] for iou in iou_thresholds} + for iou_threshold in iou_thresholds: + tp, fp = 0, 0 + + for pred in pred_list: + pred_bbox = pred['pred_bbox'] + gt_bbox = pred['gt_bbox'] + + iou = calc_iou(pred_bbox, gt_bbox) + + if iou >= iou_threshold: + tp += 1 + else: + fp += 1 + + precision = tp / (tp + fp) + recall = tp / len(pred_list) + f1 = 2 * (precision * recall) / (precision + recall) + + ap_dict[iou_threshold] = { + 'iou_threshold': iou_threshold, + 'tp': tp, + 'fp': fp, + 'n_gt': len(pred_list), + 'f1': f1, + 'precision': precision, + 'recall': recall + } + + # Convert results to a DataFrame and save to CSV + results = [] + for iou_threshold in iou_thresholds: + results.append(ap_dict[iou_threshold]) + + df = pd.DataFrame(results) + return df + + +def plot_pred_gt_support(query_image, pred, gt, support_images, support_masks, score=None, save_path="debug/pred_vs_gt"): + """ + Save 5 key images: support images, support mask, query, ground truth and prediction. + Handles both grayscale and RGB images consistently with the same mask color. + + Args: + query_image: Query image tensor (grayscale or RGB) + pred: 2d tensor where 1 represents foreground and 0 represents background + gt: 2d tensor where 1 represents foreground and 0 represents background + support_images: Support image tensors (grayscale or RGB) + support_masks: Support mask tensors + score: Optional score to add to filename + save_path: Base path without extension for saving images + """ + # Create directory for this case + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # Process query image - ensure HxWxC format for visualization + query_image = query_image.clone().detach().cpu() + if len(query_image.shape) == 3 and query_image.shape[0] <= 3: # CHW format + query_image = query_image.permute(1, 2, 0) + + # Handle grayscale vs RGB consistently + if len(query_image.shape) == 2 or (len(query_image.shape) == 3 and query_image.shape[2] == 1): + # For grayscale, use cmap='gray' for visualization + is_grayscale = True + if len(query_image.shape) == 3: + query_image = query_image.squeeze(2) # Remove channel dimension for grayscale + else: + is_grayscale = False + + # Normalize image for visualization + query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8) + + # Convert pred and gt to numpy for visualization + pred_np = pred.cpu().float().numpy() # Ensure float before converting to numpy + gt_np = gt.cpu().float().numpy() # Ensure float before converting to numpy + + # Ensure binary masks + pred_np = (pred_np > 0).astype(np.float32) + gt_np = (gt_np > 0).astype(np.float32) + + # Set all positive values to 1.0 to ensure consistent red coloring in YlOrRd colormap + pred_np[pred_np > 0] = 1.0 + gt_np[gt_np > 0] = 1.0 + + # Create colormap for mask overlays - using the YlOrRd colormap as requested + mask_cmap = plt.cm.get_cmap('YlOrRd') + + # Generate color masks with alpha values + pred_rgba = mask_cmap(pred_np) + pred_rgba[..., 3] = pred_np * 0.7 # Last channel is alpha - semitransparent where mask=1 + + gt_rgba = mask_cmap(gt_np) + gt_rgba[..., 3] = gt_np * 0.7 # Last channel is alpha - semitransparent where mask=1 + + # 1. Save query image (original) + plt.figure(figsize=(10, 10)) + if is_grayscale: + plt.imshow(query_image, cmap='gray') + else: + plt.imshow(query_image) + plt.axis('off') + # Remove padding/whitespace + plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + plt.savefig(f"{save_path}/query.png", bbox_inches='tight', pad_inches=0) + plt.close() + + # 2. Save query image with prediction overlay + plt.figure(figsize=(10, 10)) + if is_grayscale: + plt.imshow(query_image, cmap='gray') + else: + plt.imshow(query_image) + plt.imshow(pred_rgba) + plt.axis('off') + # Remove padding/whitespace + plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + plt.savefig(f"{save_path}/pred.png", bbox_inches='tight', pad_inches=0) + plt.close() + + # 3. Save query image with ground truth overlay + plt.figure(figsize=(10, 10)) + if is_grayscale: + plt.imshow(query_image, cmap='gray') + else: + plt.imshow(query_image) + plt.imshow(gt_rgba) + plt.axis('off') + # Remove padding/whitespace + plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + plt.savefig(f"{save_path}/gt.png", bbox_inches='tight', pad_inches=0) + plt.close() + + # Process and save support images and masks (just the first one for brevity) + if support_images is not None: + if isinstance(support_images, list): + support_images = torch.cat(support_images, dim=0).clone().detach() + if isinstance(support_masks, list): + support_masks = torch.cat(support_masks, dim=0).clone().detach() + + # Move to CPU for processing + support_images = support_images.cpu() + support_masks = support_masks.cpu() + + # Handle different dimensions of support images + if len(support_images.shape) == 4: # NCHW format + # Convert to NHWC for visualization + support_images = support_images.permute(0, 2, 3, 1) + + # Just process the first support image + i = 0 + if support_images.shape[0] > 0: + support_img = support_images[i].clone() + support_mask = support_masks[i].clone() + + # Check if grayscale or RGB + if support_img.shape[-1] == 1: # Last dimension is channels + support_img = support_img.squeeze(-1) # Remove channel dimension + support_is_gray = True + elif support_img.shape[-1] == 3: + support_is_gray = False + else: # Assume it's grayscale if not 1 or 3 channels + support_is_gray = True + + # Normalize support image + support_img = (support_img - support_img.min()) / (support_img.max() - support_img.min() + 1e-8) + + # 4. Save support image only + plt.figure(figsize=(10, 10)) + if support_is_gray: + plt.imshow(support_img, cmap='gray') + else: + plt.imshow(support_img) + plt.axis('off') + # Remove padding/whitespace + plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + plt.savefig(f"{save_path}/support_1.png", bbox_inches='tight', pad_inches=0) + plt.close() + + # 5. Save support mask only (direct mask visualization similar to gt/pred) + plt.figure(figsize=(10, 10)) + + # Process support mask with same approach + support_mask_np = support_mask.cpu().float().numpy() + support_mask_np = (support_mask_np > 0).astype(np.float32) + support_mask_np[support_mask_np > 0] = 1.0 # Set to 1.0 for consistent coloring + + support_mask_rgba = mask_cmap(support_mask_np) + support_mask_rgba[..., 3] = support_mask_np * 0.7 # Last channel is alpha - semitransparent where mask=1 + + if is_grayscale: + plt.imshow(support_img, cmap='gray') + else: + plt.imshow(support_img) + plt.imshow(support_mask_rgba) + plt.axis('off') + # Remove padding/whitespace + plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + plt.savefig(f"{save_path}/support_mask.png", bbox_inches='tight', pad_inches=0) + plt.close() + + + + +def get_dice_iou_precision_recall(pred: torch.Tensor, gt: torch.Tensor): + """ + pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background + gt: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background + """ + if gt.sum() == 0: + print("gt is all background") + return {"dice": 0, "precision": 0, "recall": 0} + + # Resize pred to match gt dimensions if they're different + if pred.shape != gt.shape: + print(f"Resizing prediction from {pred.shape} to match ground truth {gt.shape}") + # Use interpolate to resize pred to match gt dimensions + pred = torch.nn.functional.interpolate( + pred.unsqueeze(0).unsqueeze(0).float(), + size=gt.shape, + mode='nearest' + ).squeeze(0).squeeze(0) + + tp = (pred * gt).sum() + fp = (pred * (1 - gt)).sum() + fn = ((1 - pred) * gt).sum() + dice = 2 * tp / (2 * tp + fp + fn + 1e-8) + precision = tp / (tp + fp + 1e-8) + recall = tp / (tp + fn + 1e-8) + iou = tp / (tp + fp + fn + 1e-8) + return {"dice": dice, "iou": iou, "precision": precision, "recall": recall} + + +def get_alpnet_model(_config) -> ModelWrapper: + alpnet = FewShotSeg( + _config["input_size"][0], + _config["reload_model_path"], + _config["model"] + ) + alpnet.cuda() + alpnet_wrapper = ALPNetWrapper(alpnet) + + return alpnet_wrapper + +def get_sam_model(_config) -> ModelWrapper: + sam_args = { + "model_type": "vit_h", + "sam_checkpoint": "pretrained_model/sam_vit_h.pth" + } + sam = SamWrapper(sam_args=sam_args).cuda() + sam_wrapper = SamWrapperWrapper(sam) + return sam_wrapper + +def get_model(_config) -> ProtoSAM: + # Initial Segmentation Model + if _config["base_model"] == TYPE_ALPNET: + base_model = get_alpnet_model(_config) + else: + raise NotImplementedError(f"base model {_config['base_model']} not implemented") + + # ProtoSAM model + if _config["protosam_sam_ver"] in ("sam_h", "sam_b"): + sam_h_checkpoint = "pretrained_model/sam_vit_h.pth" + sam_b_checkpoint = "pretrained_model/sam_vit_b.pth" + sam_checkpoint = sam_h_checkpoint if _config["protosam_sam_ver"] == "sam_h" else sam_b_checkpoint + model = ProtoSAM(image_size = (1024, 1024), + coarse_segmentation_model=base_model, + use_bbox=_config["use_bbox"], + use_points=_config["use_points"], + use_mask=_config["use_mask"], + debug=_config["debug"], + num_points_for_sam=1, + use_cca=_config["do_cca"], + point_mode=_config["point_mode"], + use_sam_trans=True, + coarse_pred_only=_config["coarse_pred_only"], + sam_pretrained_path=sam_checkpoint, + use_neg_points=_config["use_neg_points"],) + elif _config["protosam_sam_ver"] == "medsam": + model = ProtoMedSAM(image_size = (1024, 1024), + coarse_segmentation_model=base_model, + debug=_config["debug"], + use_cca=_config["do_cca"], + ) + else: + raise NotImplementedError(f"protosam_sam_ver {_config['protosam_sam_ver']} not implemented") + + return model + + +def get_support_set_polyps(_config, dataset:PolypDataset): + n_support = _config["n_support"] + (support_images, support_labels, case) = dataset.get_support(n_support=n_support) + + return support_images, support_labels, case + + +def get_support_set_alpds(config, dataset:ValidationDataset): + support_set = dataset.get_support_set(config) + support_fg_masks = support_set["support_labels"] + support_images = support_set["support_images"] + support_scan_id = support_set["support_scan_id"] + return support_images, support_fg_masks, support_scan_id + + +def get_support_set(_config, dataset): + if _config["dataset"].lower() == POLYPS: + support_images, support_fg_masks, case = get_support_set_polyps(_config, dataset) + elif any(item in _config["dataset"].lower() for item in ALP_DS): + support_images, support_fg_masks, support_scan_id = get_support_set_alpds(_config, dataset) + else: + raise NotImplementedError(f"dataset {_config['dataset']} not implemented") + return support_images, support_fg_masks, support_scan_id + + +def update_support_set_by_scan_part(support_images, support_labels, qpart): + qpart_support_images = [support_images[qpart]] + qpart_support_labels = [support_labels[qpart]] + + return qpart_support_images, qpart_support_labels + + +def manage_support_sets(sample_batched, all_support_images, all_support_fg_mask, support_images, support_fg_mask, qpart=None): + if sample_batched['part_assign'][0] != qpart: + qpart = sample_batched['part_assign'][0] + support_images, support_fg_mask = update_support_set_by_scan_part(all_support_images, all_support_fg_mask, qpart) + + return support_images, support_fg_mask, qpart + + +@ex.automain +def main(_run, _config, _log): + if _run.observers: + os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True) + for source_file, _ in _run.experiment_info['sources']: + os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), + exist_ok=True) + _run.observers[0].save_file(source_file, f'source/{source_file}') + print(f"####### created dir:{_run.observers[0].dir} #######") + shutil.rmtree(f'{_run.observers[0].basedir}/_sources') + print(f"config do_cca: {_config['do_cca']}, use_bbox: {_config['use_bbox']}") + cudnn.enabled = True + cudnn.benchmark = True + torch.cuda.set_device(device=_config['gpu_id']) + torch.set_num_threads(1) + + _log.info(f'###### Reload model {_config["reload_model_path"]} ######') + print(f'###### Reload model {_config["reload_model_path"]} ######') + model = get_model(_config) + model = model.to(torch.device("cuda")) + model.eval() + + sam_trans = ResizeLongestSide(1024) + if _config["dataset"].lower() == POLYPS: + tr_dataset, te_dataset = get_polyp_dataset(sam_trans=sam_trans, image_size=(1024, 1024)) + elif CHAOS in _config["dataset"].lower() or SABS in _config["dataset"].lower(): + tr_dataset, te_dataset = get_nii_dataset(_config, _config["input_size"][0]) + else: + raise NotImplementedError( + f"dataset {_config['dataset']} not implemented") + + # dataloaders + testloader = DataLoader( + te_dataset, + batch_size=1, + shuffle=False, + num_workers=1, + pin_memory=False, + drop_last=False + ) + + _log.info('###### Starting validation ######') + model.eval() + + mean_dice = [] + mean_prec = [] + mean_rec = [] + mean_iou = [] + + mean_dice_cases = {} + mean_iou_cases = {} + bboxes_w_scores = [] + + curr_case = None + supp_fts = None + qpart = None + support_images = support_fg_mask = None + all_support_images, all_support_fg_mask, support_scan_id = None, None, None + MAX_SUPPORT_IMAGES = 1 + is_alp_ds = any(item in _config["dataset"].lower() for item in ALP_DS) + is_polyp_ds = _config["dataset"].lower() == POLYPS + + if is_alp_ds: + all_support_images, all_support_fg_mask, support_scan_id = get_support_set(_config, te_dataset) + elif is_polyp_ds: + support_images, support_fg_mask, case = get_support_set_polyps(_config, tr_dataset) + + with tqdm(testloader) as pbar: + for idx, sample_batched in enumerate(tqdm(testloader)): + case = sample_batched['case'][0] + if is_alp_ds: + support_images, support_fg_mask, qpart = manage_support_sets( + sample_batched, + all_support_images, + all_support_fg_mask, + support_images, + support_fg_mask, + qpart, + ) + + if is_alp_ds and sample_batched["scan_id"][0] in support_scan_id: + continue + + query_images = sample_batched['image'].cuda() + query_labels = torch.cat([sample_batched['label']], dim=0) + if not 1 in query_labels and _config["skip_no_organ_slices"]: + continue + + n_try = 1 + with torch.no_grad(): + coarse_model_input = InputFactory.create_input( + input_type=_config["base_model"], + query_image=query_images, + support_images=support_images, + support_labels=support_fg_mask, + isval=True, + val_wsize=_config["val_wsize"], + original_sz=query_images.shape[-2:], + img_sz=query_images.shape[-2:], + gts=query_labels, + ) + coarse_model_input.to(torch.device("cuda")) + + query_pred, scores = model( + query_images, coarse_model_input, degrees_rotate=0) + query_pred = query_pred.cpu().detach() + + if _config["debug"]: + if is_alp_ds: + save_path = f'debug/preds/{case}_{sample_batched["z_id"].item()}_{idx}_{n_try}' + os.makedirs(save_path, exist_ok=True) + elif is_polyp_ds: + save_path = f'debug/preds/{case}_{idx}_{n_try}' + os.makedirs(save_path, exist_ok=True) + plot_pred_gt_support(query_images[0,0].cpu(), query_pred.cpu(), query_labels[0].cpu(), + support_images, support_fg_mask, save_path=save_path, score=scores[0]) + + # print(query_pred.shape) + # print(query_labels[0].shape) + metrics = get_dice_iou_precision_recall( + query_pred, query_labels[0].to(query_pred.device)) + mean_dice.append(metrics["dice"]) + mean_prec.append(metrics["precision"]) + mean_rec.append(metrics["recall"]) + mean_iou.append(metrics["iou"]) + + bboxes_w_scores.append({"pred_bbox": get_bounding_box(query_pred.cpu()), + "gt_bbox": get_bounding_box(query_labels[0].cpu()), + "score": np.mean(scores)}) + + if case not in mean_dice_cases: + mean_dice_cases[case] = [] + mean_iou_cases[case] = [] + mean_dice_cases[case].append(metrics["dice"]) + mean_iou_cases[case].append(metrics["iou"]) + + if metrics["dice"] < 0.6 and _config["debug"]: + path = f'{_run.observers[0].dir}/bad_preds/case_{case}_idx_{idx}_dice_{metrics["dice"]:.4f}' + if _config["debug"]: + path = f'debug/bad_preds/case_{case}_idx_{idx}_dice_{metrics["dice"]:.4f}' + os.makedirs(path, exist_ok=True) + print(f"saving bad prediction to {path}") + plot_pred_gt_support(query_images[0,0].cpu(), query_pred.cpu(), query_labels[0].cpu( + ), support_images, support_fg_mask, save_path=path, score=scores[0]) + + pbar.set_postfix_str({"mdice": f"{np.mean(mean_dice):.4f}", "miou": f"{np.mean(mean_iou):.4f}, n_try: {n_try}"}) + + + for k in mean_dice_cases.keys(): + _run.log_scalar(f'mar_val_batches_meanDice_{k}', np.mean(mean_dice_cases[k])) + _run.log_scalar(f'mar_val_batches_meanIOU_{k}', np.mean(mean_iou_cases[k])) + _log.info(f'mar_val batches meanDice_{k}: {np.mean(mean_dice_cases[k])}') + _log.info(f'mar_val batches meanIOU_{k}: {np.mean(mean_iou_cases[k])}') + + # write validation result to log file + m_meanDice = np.mean(mean_dice) + m_meanPrec = np.mean(mean_prec) + m_meanRec = np.mean(mean_rec) + m_meanIOU = np.mean(mean_iou) + + _run.log_scalar('mar_val_batches_meanDice', m_meanDice) + _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec) + _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec) + _run.log_scalar('mar_val_al_batches_meanIOU', m_meanIOU) + _log.info(f'mar_val batches meanDice: {m_meanDice}') + _log.info(f'mar_val batches meanPrec: {m_meanPrec}') + _log.info(f'mar_val batches meanRec: {m_meanRec}') + _log.info(f'mar_val batches meanIOU: {m_meanIOU}') + print("============ ============") + _log.info(f'End of validation') + return 1 \ No newline at end of file