diff --git a/.gitignore b/.gitignore index 61881b8a4f3eb751c0343b99778348ba0d08d1ac..cb32b8601ed8a2b36bf1c931fccc6689130c1eca 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ __pycache__/ !/input/example.png /models/ /temp/ -/custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs diff --git a/app.py b/app.py index 455f8bd78d2c6afc9512474bccc1f1fa5eb4d845..0aed6bee44c0185e68afe4c10621981adeb77d47 100644 --- a/app.py +++ b/app.py @@ -6,10 +6,10 @@ import gradio as gr import torch from huggingface_hub import hf_hub_download from nodes import NODE_CLASS_MAPPINGS -import spaces from comfy import model_management -@spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds +# import spaces +# @spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any: diff --git a/custom_nodes/ComfyUI-ReActor/.gitignore b/custom_nodes/ComfyUI-ReActor/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c7653ac8cb5181e73d71f2b6847dc4a073e9f792 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +*$py.class +.vscode/ +example +input diff --git a/custom_nodes/ComfyUI-ReActor/LICENSE b/custom_nodes/ComfyUI-ReActor/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. 
+ + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. 
If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. 
+ + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/custom_nodes/ComfyUI-ReActor/README.md b/custom_nodes/ComfyUI-ReActor/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc0ffd57dca6969f9855ce66e1c2a5b1ddd5884d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/README.md @@ -0,0 +1,488 @@ +
+ + logo + + ![Version](https://img.shields.io/badge/node_version-0.6.0_alpha1-lightgreen?style=for-the-badge&labelColor=darkgreen) + + + + + Support Me on Boosty +
+ + Support This Project + +
+ +
+ + [![Commit activity](https://img.shields.io/github/commit-activity/t/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/commits/main) + ![Last commit](https://img.shields.io/github/last-commit/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0) + [![Opened issues](https://img.shields.io/github/issues/Gourieff/ComfyUI-ReActor?color=red)](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0) + [![Closed issues](https://img.shields.io/github/issues-closed/Gourieff/ComfyUI-ReActor?color=green&cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed) + ![License](https://img.shields.io/github/license/Gourieff/ComfyUI-ReActor) + + English | [Русский](/README_RU.md) + +# ReActor Nodes for ComfyUI
-=SFW-Friendly=- + +
+
+### The Fast and Simple Face Swap Extension Nodes for ComfyUI, based on the [blocked ReActor](https://github.com/Gourieff/comfyui-reactor-node) - it now includes a nudity detector to prevent this software from being used with 18+ content
+
+> By using this Node you accept and assume [responsibility](#disclaimer)
+
+ +--- +[**What's new**](#latestupdate) | [**Installation**](#installation) | [**Usage**](#usage) | [**Troubleshooting**](#troubleshooting) | [**Updating**](#updating) | [**Disclaimer**](#disclaimer) | [**Credits**](#credits) | [**Note!**](#note) + +--- + +
+
+
+
+## What's new in the latest update
+
+### 0.6.0 ALPHA1
+
+- New Node `ReActorSetWeight` - you can now set the strength of the face swap for `source_image` or `face_model` from 0% to 100% (in 12.5% steps)
+
+0.6.0-whatsnew-01 +0.6.0-whatsnew-02 +0.6.0-whatsnew-03 +
+ +
+ Previous versions
+
+### 0.5.2
+
+- ReSwapper models support. Inswapper still has the best similarity, but ReSwapper is evolving - thanks to @somanchiu https://github.com/somanchiu/ReSwapper for the ReSwapper models and the ReSwapper project! This is a good step for the Community toward creating an alternative to Inswapper!
+
+0.5.2-whatsnew-03 +0.5.2-whatsnew-04 +
+
+You can download ReSwapper models here:
+https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models
+Just put them into the "models/reswapper" directory.
+
+- NSFW-detector, so as not to violate the [GitHub rules](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools)
+- New node "Unload ReActor Models" - useful for complex workflows when you need to free some VRAM utilized by ReActor
+
+0.5.2-whatsnew-01
+
+- Support of the ORT CoreML and ROCM EPs, just install the onnxruntime version you need
+- Install script improvements to install the latest versions of ORT-GPU
+
+0.5.2-whatsnew-02 +
+ +- Fixes and improvements + + +### 0.5.1 + +- Support of GPEN 1024/2048 restoration models (available in the HF dataset https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models) +- ReActorFaceBoost Node - an attempt to improve the quality of swapped faces. The idea is to restore and scale the swapped face (according to the `face_size` parameter of the restoration model) BEFORE pasting it to the target image (via inswapper algorithms), more information is [here (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321) + +0.5.1-whatsnew-01 + +[Full size demo preview](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png) + +- Sorting facemodels alphabetically +- A lot of fixes and improvements + +### [0.5.0 BETA4](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0) + +- Spandrel lib support for GFPGAN + +### 0.5.0 BETA3 + +- Fixes: "RAM issue", "No detection" for MaskingHelper + +### 0.5.0 BETA2 + +- You can now build a blended face model from a batch of face models you already have, just add the "Make Face Model Batch" node to your workflow and connect several models via "Load Face Model" +- Huge performance boost of the image analyzer's module! 10x speed up! Working with videos is now a pleasure! + +0.5.0-whatsnew-05 + +### 0.5.0 BETA1 + +- SWAPPED_FACE output for the Masking Helper Node +- FIX: Empty A-channel for Masking Helper IMAGE output (causing errors with some nodes) was removed + +### 0.5.0 ALPHA1 + +- ReActorBuildFaceModel Node got "face_model" output to provide a blended face model directly to the main Node: + +Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json) + +- Face Masking feature is available now, just add the "ReActorMaskHelper" Node to the workflow and connect it as shown below: + +0.5.0-whatsnew-01 + +If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory +
+As well as the ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) model - download it (if you don't have it) and put it into the "ComfyUI\models\sams" directory;
+
+Use this Node to get the best results from the face swapping process:
+
+0.5.0-whatsnew-02
+
+- ReActorImageDublicator Node - rather useful for those who create videos, it helps to duplicate one image into several frames to use them with the VAE Encoder (e.g. live avatars):
+
+0.5.0-whatsnew-03
+
+- ReActorFaceSwapOpt (a simplified version of the Main Node) + ReActorOptions Nodes to set some additional options such as the (new) "input/source faces separate order". Yes! You can now set the order of faces in the index in the way you want ("large to small" is the default)!
+
+0.5.0-whatsnew-04
+
+- Little speed boost when analyzing target images (unfortunately it is still quite slow compared to swapping and restoring...)
+
+### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2)
+
+- GPEN-BFR-512 and RestoreFormer_Plus_Plus face restoration models support
+
+You can download models here: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
Put them into the `ComfyUI\models\facerestore_models` folder + +0.4.2-whatsnew-04 + +- Due to popular demand - you can now blend several images with persons into one face model file and use it with "Load Face Model" Node or in SD WebUI as well; + +Experiment and create new faces or blend faces of one person to gain better accuracy and likeness! + +Just add the ImpactPack's "Make Image Batch" Node as the input to the ReActor's one and load images you want to blend into one model: + +0.4.2-whatsnew-01 + +Result example (the new face was created from 4 faces of different actresses): + +0.4.2-whatsnew-02 + +Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json) + +### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1) + +- CUDA 12 Support - don't forget to run (Windows) `install.bat` or (Linux/MacOS) `install.py` for ComfyUI's Python enclosure or try to install ORT-GPU for CU12 manually (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x) +- Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173 fix + +- Separate Node for the Face Restoration postprocessing (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), can be found inside ReActor's menu (RestoreFace Node) +- (Windows) Installation can be done for Python from the System's PATH +- Different fixes and improvements + +- Face Restore Visibility and CodeFormer Weight (Fidelity) options are now available! Don't forget to reload the Node in your existing workflow + +0.4.1-whatsnew-01 + +### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0) + +- Input "input_image" goes first now, it gives a correct bypass and also it is right to have the main input first; +- You can now save face models as "safetensors" files (`ComfyUI\models\reactor\faces`) and load them into ReActor implementing different scenarios and keeping super lightweight face models of the faces you use: + +0.4.0-whatsnew-01 +0.4.0-whatsnew-02 + +- Ability to build and save face models directly from an image: + +0.4.0-whatsnew-03 + +- Both the inputs are optional, just connect one of them according to your workflow; if both is connected - `image` has a priority. +- Different fixes making this extension better. + +Thanks to everyone who finds bugs, suggests new features and supports this project! + +
+ +## Installation + +
+ SD WebUI: AUTOMATIC1111 or SD.Next
+
+1. Close (stop) your SD-WebUI/Comfy Server if it's running
+2. (For Windows Users):
+  - Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
+  - OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
+  - OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
+3. Go to the `extensions\sd-webui-comfyui\ComfyUI\custom_nodes`
+4. Open Console or Terminal and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
+5. Go to the SD WebUI root folder, open Console or Terminal and run (Windows users) `.\venv\Scripts\activate` or (Linux/MacOS) `venv/bin/activate`
+6. `python -m pip install -U pip`
+7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor`
+8. `python install.py`
+9. Wait until the installation process is finished
+10. (Since version 0.3.0) Download the additional face restoration models from the link below and put them into the `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models` directory (a scripted alternative is sketched just after these steps):
+https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
+11. Run SD WebUI and check the console for the message that the ReActor Node is running:
+console_status_running
+
+12. Go to the ComfyUI tab and find the ReActor Node inside the `ReActor` menu or by using the search:
+webui-demo
+webui-demo
+
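+
+If you would rather script step 10, the snippet below is a minimal sketch (not part of the official instructions) that fetches the face restoration models with `hf_hub_download` and copies them into place; the file list and the target path are assumptions, so adjust them to your setup:
+
+```python
+# Hypothetical helper: download the extra face restoration models from the
+# Gourieff/ReActor Hugging Face dataset and copy them into the folder that
+# ReActor reads them from. Requires the huggingface_hub package.
+import shutil
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+TARGET_DIR = Path(r"extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models")  # adjust to your install
+TARGET_DIR.mkdir(parents=True, exist_ok=True)
+
+for name in ["codeformer-v0.1.0.pth", "GFPGANv1.3.pth", "GFPGANv1.4.pth"]:
+    cached = hf_hub_download(
+        repo_id="Gourieff/ReActor",
+        repo_type="dataset",
+        filename=f"models/facerestore_models/{name}",
+    )
+    shutil.copy2(cached, TARGET_DIR / name)  # flat copy where ReActor expects the file
+```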
+ +
+ Standalone (Portable) ComfyUI for Windows
+
+1. Do the following:
+  - Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
+  - OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
+  - OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
+2. Choose between two options:
+  - (ComfyUI Manager) Open ComfyUI Manager, click "Install Custom Nodes", type "ReActor" in the "Search" field and then click "Install". After ComfyUI completes the process, please restart the Server.
+  - (Manually) Go to `ComfyUI\custom_nodes`, open Console and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
+3. Go to `ComfyUI\custom_nodes\ComfyUI-ReActor` and run `install.bat`
+4. If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory
As well as one or both of "Sams" models from [here](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - download (if you don't have them) and put into the "ComfyUI\models\sams" directory +5. Run ComfyUI and find there ReActor Nodes inside the menu `ReActor` or by using a search + +
+ +## Usage + +You can find ReActor Nodes inside the menu `ReActor` or by using a search (just type "ReActor" in the search field) + +List of Nodes: +- ••• Main Nodes ••• + - ReActorFaceSwap (Main Node) + - ReActorFaceSwapOpt (Main Node with the additional Options input) + - ReActorOptions (Options for ReActorFaceSwapOpt) + - ReActorFaceBoost (Face Booster Node) + - ReActorMaskHelper (Masking Helper) +- ••• Operations with Face Models ••• + - ReActorSaveFaceModel (Save Face Model) + - ReActorLoadFaceModel (Load Face Model) + - ReActorBuildFaceModel (Build Blended Face Model) + - ReActorMakeFaceModelBatch (Make Face Model Batch) +- ••• Additional Nodes ••• + - ReActorRestoreFace (Face Restoration) + - ReActorImageDublicator (Dublicate one Image to Images List) + - ImageRGBA2RGB (Convert RGBA to RGB) + +Connect all required slots and run the query. + +### Main Node Inputs + +- `input_image` - is an image to be processed (target image, analog of "target image" in the SD WebUI extension); + - Supported Nodes: "Load Image", "Load Video" or any other nodes providing images as an output; +- `source_image` - is an image with a face or faces to swap in the `input_image` (source image, analog of "source image" in the SD WebUI extension); + - Supported Nodes: "Load Image" or any other nodes providing images as an output; +- `face_model` - is the input for the "Load Face Model" Node or another ReActor node to provide a face model file (face embedding) you created earlier via the "Save Face Model" Node; + - Supported Nodes: "Load Face Model", "Build Blended Face Model"; + +### Main Node Outputs + +- `IMAGE` - is an output with the resulted image; + - Supported Nodes: any nodes which have images as an input; +- `FACE_MODEL` - is an output providing a source face's model being built during the swapping process; + - Supported Nodes: "Save Face Model", "ReActor", "Make Face Model Batch"; + +### Face Restoration + +Since version 0.3.0 ReActor Node has a buil-in face restoration.
Just download the models you want (see [Installation](#installation) instruction) and select one of them to restore the resulting face(s) during the faceswap. It will enhance face details and make your result more accurate. + +### Face Indexes + +By default ReActor detects faces in images from "large" to "small".
You can change this option by adding ReActorFaceSwapOpt node with ReActorOptions. + +And if you need to specify faces, you can set indexes for source and input images. + +Index of the first detected face is 0. + +You can set indexes in the order you need.
+E.g.: 0,1,2 (for Source); 1,0,2 (for Input).
This means the second Input face (index = 1) will be replaced with the first Source face (index = 0), and so on.
+
+### Genders
+
+You can specify the gender to detect in images.
+ReActor will swap a face only if it meets the given condition. + +### Face Models + +Since version 0.4.0 you can save face models as "safetensors" files (stored in `ComfyUI\models\reactor\faces`) and load them into ReActor implementing different scenarios and keeping super lightweight face models of the faces you use. + +To make new models appear in the list of the "Load Face Model" Node - just refresh the page of your ComfyUI web application.
+(I recommend using ComfyUI Manager - otherwise your workflow can be lost after you refresh the page if you didn't save it before that).
+
+## Troubleshooting
+
+
+
+### **I. (For Windows users) If you still cannot build Insightface for some reason or just don't want to install Visual Studio or VS C++ Build Tools - do the following:**
+
+1. (ComfyUI Portable) From the root folder check the version of Python:
run CMD and type `python_embeded\python.exe -V` +2. Download prebuilt Insightface package [for Python 3.10](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp310-cp310-win_amd64.whl) or [for Python 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (if in the previous step you see 3.11) or [for Python 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (if in the previous step you see 3.12) and put into the stable-diffusion-webui (A1111 or SD.Next) root folder (where you have "webui-user.bat" file) or into ComfyUI root folder if you use ComfyUI Portable +3. From the root folder run: + - (SD WebUI) CMD and `.\venv\Scripts\activate` + - (ComfyUI Portable) run CMD +4. Then update your PIP: + - (SD WebUI) `python -m pip install -U pip` + - (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip` +5. Then install Insightface: + - (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12) + - (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12) +6. Enjoy! + +### **II. "AttributeError: 'NoneType' object has no attribute 'get'"** + +This error may occur if there's smth wrong with the model file `inswapper_128.onnx` + +Try to download it manually from [here](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) +and put it to the `ComfyUI\models\insightface` replacing existing one + +### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"** + +This means that input points have been changed with the latest update
+Remove the current ReActor Node from your workflow and add it again + +### **IV. ControlNet Aux Node IMPORT failed error when using with ReActor Node** + +1. Close ComfyUI if it runs +2. Go to the ComfyUI root folder, open CMD there and run: + - `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless` + - `python_embeded\python.exe -m pip install opencv-python==4.7.0.72` +3. That's it! + +reactor+controlnet + +### **V. "ModuleNotFoundError: No module named 'basicsr'" or "subprocess-exited-with-error" during future-0.18.3 installation** + +- Download https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl
+
+- Put it into the ComfyUI root folder and run:
+
+    python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl
+
+- Then:
+
+    python_embeded\python.exe -m pip install basicsr
+
+### **VI. "fatal: fetch-pack: invalid index-pack output" when you try to `git clone` the repository**
+
+Try to clone with `--depth=1` (last commit only):
+
+    git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor
+
+Then retrieve the rest (if you need it):
+
+    git fetch --unshallow
+
+## Updating
+
+Just put the .bat or .sh script from this [Repo](https://github.com/Gourieff/sd-webui-extensions-updater) into the `ComfyUI\custom_nodes` directory and run it when you need to check for updates
+
+### Disclaimer
+
+This software is meant to be a productive contribution to the rapidly growing AI-generated media industry. It will help artists with tasks such as animating a custom character or using the character as a model for clothing etc.
+
+The developers of this software are aware of its possible unethical applications and are committed to taking preventative measures against them. We will continue to develop this project in a positive direction while adhering to law and ethics.
+
+Users of this software are expected to use it responsibly while abiding by local laws. If the face of a real person is being used, users are advised to get consent from the person concerned and to clearly mention that it is a deepfake when posting content online. **Developers and Contributors of this software are not responsible for the actions of end-users.**
+
+By using this extension you agree not to create any content that:
+- violates any laws;
+- causes any harm to a person or persons;
+- propagates (spreads) any information (both public or personal) or images (both public or personal) which could be meant for harm;
+- spreads misinformation;
+- targets vulnerable groups of people.
+
+This software utilizes the pre-trained models `buffalo_l` and `inswapper_128.onnx`, which are provided by [InsightFace](https://github.com/deepinsight/insightface/). These models are included under the following conditions:
+
+[From the InsightFace license](https://github.com/deepinsight/insightface/tree/master/python-package): InsightFace's pre-trained models are available for non-commercial research purposes only. This includes both the auto-downloading models and manually downloaded models.
+
+Users of this software must strictly adhere to these conditions of use. The developers and maintainers of this software are not responsible for any misuse of InsightFace's pre-trained models.
+
+Please note that if you intend to use this software for any commercial purposes, you will need to train your own models or find models that can be used commercially.
+ +### Models Hashsum + +#### Safe-to-use models have the following hash: + +inswapper_128.onnx +``` +MD5:a3a155b90354160350efd66fed6b3d80 +SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af +``` + +1k3d68.onnx + +``` +MD5:6fb94fcdb0055e3638bf9158e6a108f4 +SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +``` + +2d106det.onnx + +``` +MD5:a3613ef9eb3662b4ef88eb90db1fcf26 +SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +``` + +det_10g.onnx + +``` +MD5:4c10eef5c9e168357a16fdd580fa8371 +SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +``` + +genderage.onnx + +``` +MD5:81c77ba87ab38163b0dec6b26f8e2af2 +SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +``` + +w600k_r50.onnx + +``` +MD5:80248d427976241cbd1343889ed132b3 +SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43 +``` + +**Please check hashsums if you download these models from unverified (or untrusted) sources** + +
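+
+If you want to compare these values against a file you downloaded, the sketch below is one way to do it with Python's standard `hashlib` module; it is not part of the node, and the file path is only an example - adjust it to wherever you saved the model:
+
+```
+# Example: print the MD5 and SHA256 of a downloaded model file
+# (the path below is illustrative; point it at your own download)
+import hashlib
+
+def file_hashes(path, chunk_size=1 << 20):
+    md5, sha256 = hashlib.md5(), hashlib.sha256()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):
+            md5.update(chunk)
+            sha256.update(chunk)
+    return md5.hexdigest(), sha256.hexdigest()
+
+print(file_hashes("ComfyUI/models/insightface/inswapper_128.onnx"))
+```
+
+Compare the printed values with the MD5/SHA256 listed above before using the file.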
+ +## Thanks and Credits + +
+ Click to expand + +
+ +|file|source|license| +|----|------|-------| +|[buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) | +| [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) | +| [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) | ![license](https://img.shields.io/badge/license-non_commercial-red) | + +[BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup)
+[facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao)
+ +[@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - the original Roop App
+[@ssitu](https://github.com/ssitu) - the first version of [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop) extension + +
+ + + +### Note! + +**If you encounter any errors when you use the ReActor Node - don't rush to open an issue; first try to remove the current ReActor node from your workflow and add it again** + +**ReActor Node gets updates from time to time, new functions appear, and an old node may work with errors or not work at all** diff --git a/custom_nodes/ComfyUI-ReActor/README_RU.md b/custom_nodes/ComfyUI-ReActor/README_RU.md new file mode 100644 index 0000000000000000000000000000000000000000..0f9962d14c3720197a0983383c79db9872d651c5 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/README_RU.md @@ -0,0 +1,497 @@ +
+ + logo + + ![Version](https://img.shields.io/badge/версия_нода-0.6.0_alpha1-lightgreen?style=for-the-badge&labelColor=darkgreen) + + + + + Поддержать проект на Boosty +
+ + Поддержать проект + +
+ +
+ + [![Commit activity](https://img.shields.io/github/commit-activity/t/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/commits/main) + ![Last commit](https://img.shields.io/github/last-commit/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0) + [![Opened issues](https://img.shields.io/github/issues/Gourieff/ComfyUI-ReActor?color=red)](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0) + [![Closed issues](https://img.shields.io/github/issues-closed/Gourieff/ComfyUI-ReActor?color=green&cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed) + ![License](https://img.shields.io/github/license/Gourieff/ComfyUI-ReActor) + + [English](/README.md) | Русский + +# ReActor Nodes для ComfyUI
-=Безопасно для работы | SFW-Friendly=- + +
+ +### Ноды (nodes) для быстрой и простой замены лиц на любых изображениях для работы с ComfyUI, основан на [ранее заблокированном РеАкторе](https://github.com/Gourieff/comfyui-reactor-node) - теперь имеется встроенный NSFW-детектор, исключающий замену лиц на изображениях с контентом 18+ + +> Используя данное ПО, вы понимаете и принимаете [ответственность](#disclaimer) + +
+ +--- +[**Что нового**](#latestupdate) | [**Установка**](#installation) | [**Использование**](#usage) | [**Устранение проблем**](#troubleshooting) | [**Обновление**](#updating) | [**Ответственность**](#disclaimer) | [**Благодарности**](#credits) | [**Заметка**](#note) + +--- + +
+ + + +## Что нового в последнем обновлении + +### 0.6.0 ALPHA1 + +- Новый нод `ReActorSetWeight` - теперь можно установить силу замены лица для `source_image` или `face_model` от 0% до 100% (с шагом 12.5%) + +
+0.6.0-whatsnew-01 +0.6.0-whatsnew-02 +0.6.0-whatsnew-03 +
+ +
+ Предыдущие версии + +### 0.5.2 + +- Поддержка моделей ReSwapper. Несмотря на то, что Inswapper по-прежнему даёт лучшее сходство, но ReSwapper развивается - спасибо @somanchiu https://github.com/somanchiu/ReSwapper за эти модели и проект ReSwapper! Это хороший шаг для Сообщества в создании альтернативы Инсваппера! + +
+0.5.2-whatsnew-03 +0.5.2-whatsnew-04 +
+ +Скачать модели ReSwapper можно отсюда: +https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models +Сохраните их в директорию "models/reswapper". + +- NSFW-детектор, чтобы не нарушать [правила GitHub](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools) +- Новый нод "Unload ReActor Models" - полезен для сложных воркфлоу, когда вам нужно освободить ОЗУ, занятую РеАктором + +0.5.2-whatsnew-01 + +- Поддержка ORT CoreML and ROCM EPs, достаточно установить ту версию onnxruntime, которая соответствует вашему GPU +- Некоторые улучшения скрипта установки для поддержки последней версии ORT-GPU + +
+0.5.2-whatsnew-02 +
+ +- Исправления и улучшения + +### 0.5.1 + +- Поддержка моделей восстановления лиц GPEN 1024/2048 (доступны в датасете на HF https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models) +- Нод ReActorFaceBoost - попытка улучшить качество заменённых лиц. Идея состоит в том, чтобы восстановить и увеличить заменённое лицо (в соответствии с параметром `face_size` модели реставрации) ДО того, как лицо будет вставлено в целевое изображения (через алгоритмы инсваппера), больше информации [здесь (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321) + +0.5.1-whatsnew-01 + +[Полноразмерное демо-превью](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png) + +- Сортировка моделей лиц по алфавиту +- Множество исправлений и улучшений + +### [0.5.0 BETA4](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0) + +- Поддержка библиотеки Spandrel при работе с GFPGAN + +### 0.5.0 BETA3 + +- Исправления: "RAM issue", "No detection" для MaskingHelper + +### 0.5.0 BETA2 + +- Появилась возможность строить смешанные модели лиц из пачки уже имеющихся моделей - добавьте для этого нод "Make Face Model Batch" в свой воркфлоу и загрузите несколько моделей через ноды "Load Face Model" +- Огромный буст производительности модуля анализа изображений! 10-кратный прирост скорости! Работа с видео теперь в удовольствие! + +0.5.0-whatsnew-05 + +### 0.5.0 BETA1 + +- Добавлен выход SWAPPED_FACE для нода Masking Helper +- FIX: Удалён пустой A-канал на выходе IMAGE нода Masking Helper (вызывавший ошибки с некоторым нодами) + +### 0.5.0 ALPHA1 + +- Нод ReActorBuildFaceModel получил выход "face_model" для отправки совмещенной модели лиц непосредственно в основной Нод: + +Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json) + +- Функции маски лица теперь доступна и в версии для Комфи, просто добавьте нод "ReActorMaskHelper" в воркфлоу и соедините узлы, как показано ниже: + +0.5.0-whatsnew-01 + +Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox" +
+То же самое и с ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) - скачайте (если отсутствует) и положите в папку "ComfyUI\models\sams"; + +Данный нод поможет вам получить куда более аккуратный результат при замене лиц: + +0.5.0-whatsnew-02 + +- Нод ReActorImageDublicator - полезен тем, кто создает видео, помогает продублировать одиночное изображение в несколько копий, чтобы использовать их, к примеру, с VAE энкодером: + +0.5.0-whatsnew-03 + +- ReActorFaceSwapOpt (упрощенная версия основного нода) + нод ReActorOptions для установки дополнительных опций, как (новые) "отдельный порядок лиц для input/source". Да! Теперь можно установить любой порядок "чтения" индекса лиц на изображении, в т.ч. от большего к меньшему (по умолчанию)! + +0.5.0-whatsnew-04 + +- Небольшое улучшение скорости анализа целевых изображений (input) + +### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2) + +- Добавлена поддержка GPEN-BFR-512 и RestoreFormer_Plus_Plus моделей восстановления лиц + +Скачать можно здесь: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models +
Добавьте модели в папку `ComfyUI\models\facerestore_models` + +0.4.2-whatsnew-04 + +- По многочисленным просьбам появилась возможность строить смешанные модели лиц и в ComfyUI тоже и использовать их с нодом "Load Face Model" Node или в SD WebUI; + +Экспериментируйте и создавайте новые лица или совмещайте разные лица нужного вам персонажа, чтобы добиться лучшей точности и схожести с оригиналом! + +Достаточно добавить нод "Make Image Batch" (ImpactPack) на вход нового нода РеАктора и загрузить в пачку необходимые вам изображения для построения смешанной модели: + +0.4.2-whatsnew-01 + +Пример результата (на основе лиц 4-х актрис создано новое лицо): + +0.4.2-whatsnew-02 + +Базовый воркфлоу [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json) + +### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1) + +- Поддержка CUDA 12 - не забудьте запустить (Windows) `install.bat` или (Linux/MacOS) `install.py` для используемого Python окружения или попробуйте установить ORT-GPU для CU12 вручную (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x) +- Исправление Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173 + +- Отдельный Нод для восстаноления лиц (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), располагается внутри меню ReActor (нод RestoreFace) +- (Windows) Установка зависимостей теперь может быть выполнена в Python из PATH ОС +- Разные исправления и улучшения + +- Face Restore Visibility и CodeFormer Weight (Fidelity) теперь доступны; не забудьте заново добавить Нод в ваших существующих воркфлоу + +0.4.1-whatsnew-01 + +### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0) + +- Вход "input_image" теперь идёт первым, это даёт возможность корректного байпаса, а также это правильно с точки зрения расположения входов (главный вход - первый); +- Теперь можно сохранять модели лиц в качестве файлов "safetensors" (`ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете: + +0.4.0-whatsnew-01 +0.4.0-whatsnew-02 + +- Возможность сохранять модели лиц напрямую из изображения: + +0.4.0-whatsnew-03 + +- Оба входа опциональны, присоедините один из них в соответствии с вашим воркфлоу; если присоеденены оба - вход `image` имеет приоритет. +- Различные исправления, делающие это расширение лучше. + +Спасибо всем, кто находит ошибки, предлагает новые функции и поддерживает данный проект! + +
+ + + +## Установка + +
+ SD WebUI: AUTOMATIC1111 или SD.Next + +1. Закройте (остановите) SD-WebUI Сервер, если запущен +2. (Для пользователей Windows): + - Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface) + - ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile" + - ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild) +3. Перейдите в `extensions\sd-webui-comfyui\ComfyUI\custom_nodes` +4. Откройте Консоль или Терминал и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor` +5. Перейдите в корневую директорию SD WebUI, откройте Консоль или Терминал и выполните (для пользователей Windows)`.\venv\Scripts\activate` или (для пользователей Linux/MacOS)`venv/bin/activate` +6. `python -m pip install -U pip` +7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor` +8. `python install.py` +9. Пожалуйста, дождитесь полного завершения установки +10. (Начиная с версии 0.3.0) Скачайте дополнительные модели восстановления лиц (по ссылке ниже) и сохраните их в папку `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models`:
+https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models +11. Запустите SD WebUI и проверьте консоль на сообщение, что ReActor Node работает: +console_status_running + +12. Перейдите во вкладку ComfyUI и найдите там ReActor Node внутри меню `ReActor` или через поиск: +webui-demo +webui-demo + +
+ +
+ Портативная версия ComfyUI для Windows + +1. Сделайте следующее: + - Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface) + - ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile" + - ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild) +2. Выберите из двух вариантов: + - (ComfyUI Manager) Откройте ComfyUI Manager, нажвите "Install Custom Nodes", введите "ReActor" в поле "Search" и далее нажмите "Install". После того, как ComfyUI завершит установку, перезагрузите сервер. + - (Вручную) Перейдите в `ComfyUI\custom_nodes`, откройте Консоль и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor` +3. Перейдите `ComfyUI\custom_nodes\ComfyUI-ReActor` и запустите `install.bat`, дождитесь окончания установки +4. Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox"
+То же самое и с "Sams" моделями, скачайте одну или обе [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - и положите в папку "ComfyUI\models\sams" +5. Запустите ComfyUI и найдите ReActor Node внутри меню `ReActor` или через поиск + +
+ + + +## Использование + +Вы можете найти ноды ReActor внутри меню `ReActor` или через поиск (достаточно ввести "ReActor" в поисковой строке) + +Список нодов: +- ••• Main Nodes ••• + - ReActorFaceSwap (Основной нод) + - ReActorFaceSwapOpt (Основной нод с доп. входом Options) + - ReActorOptions (Опции для ReActorFaceSwapOpt) + - ReActorFaceBoost (Нод Face Booster) + - ReActorMaskHelper (Masking Helper) +- ••• Operations with Face Models ••• + - ReActorSaveFaceModel (Save Face Model) + - ReActorLoadFaceModel (Load Face Model) + - ReActorBuildFaceModel (Build Blended Face Model) + - ReActorMakeFaceModelBatch (Make Face Model Batch) +- ••• Additional Nodes ••• + - ReActorRestoreFace (Face Restoration) + - ReActorImageDublicator (Dublicate one Image to Images List) + - ImageRGBA2RGB (Convert RGBA to RGB) + +Соедините все необходимые слоты (slots) и запустите очередь (query). + +### Входы основного Нода + +- `input_image` - это изображение, на котором надо поменять лицо или лица (целевое изображение, аналог "target image" в версии для SD WebUI); + - Поддерживаемые ноды: "Load Image", "Load Video" или любые другие ноды предоставляющие изображение в качестве выхода; +- `source_image` - это изображение с лицом или лицами для замены (изображение-источник, аналог "source image" в версии для SD WebUI); + - Поддерживаемые ноды: "Load Image" или любые другие ноды с выходом Image(s); +- `face_model` - это вход для выхода с нода "Load Face Model" или другого нода ReActor для загрузки модели лица (face model или face embedding), которое вы создали ранее через нод "Save Face Model"; + - Поддерживаемые ноды: "Load Face Model", "Build Blended Face Model"; + +### Выходы основного Нода + +- `IMAGE` - выход с готовым изображением (результатом); + - Поддерживаемые ноды: любые ноды с изображением на входе; +- `FACE_MODEL` - выход, предоставляющий модель лица, построенную в ходе замены; + - Поддерживаемые ноды: "Save Face Model", "ReActor", "Make Face Model Batch"; + +### Восстановление лиц + +Начиная с версии 0.3.0 ReActor Node имеет встроенное восстановление лиц.
Скачайте нужные вам модели (см. инструкцию по [Установке](#installation)) и выберите одну из них, чтобы улучшить качество финального лица. + +### Индексы Лиц (Face Indexes) + +По умолчанию ReActor определяет лица на изображении в порядке от "большого" к "малому".
Вы можете поменять эту опцию, используя нод ReActorFaceSwapOpt вместе с ReActorOptions. + +Если вам нужно заменить определенное лицо, вы можете указать индекс для исходного (source, с лицом) и входного (input, где будет замена лица) изображений. + +Индекс первого обнаруженного лица - 0. + +Вы можете задать индексы в том порядке, который вам нужен.
+Например: 0,1,2 (для Source); 1,0,2 (для Input).
Это означает, что: второе лицо из Input (индекс = 1) будет заменено первым лицом из Source (индекс = 0) и так далее. + +### Определение Пола + +Вы можете обозначить, какой пол нужно определять на изображении.
+ReActor заменит только то лицо, которое удовлетворяет заданному условию. + +### Модели Лиц +Начиная с версии 0.4.0, вы можете сохранять модели лиц как файлы "safetensors" (хранятся в папке `ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете. + +Чтобы новые модели появились в списке моделей нода "Load Face Model" - обновите страницу of с ComfyUI.
+(Рекомендую использовать ComfyUI Manager - иначе ваше воркфлоу может быть потеряно после перезагрузки страницы, если вы не сохранили его). + +
+ +## Устранение проблем + + + +### **I. (Для пользователей Windows) Если вы до сих пор не можете установить пакет Insightface по каким-то причинам или же просто не желаете устанавливать Visual Studio или VS C++ Build Tools - сделайте следующее:** + +1. (ComfyUI Portable) Находясь в корневой директории, проверьте версию Python:
запустите CMD и выполните `python_embeded\python.exe -V`
Вы должны увидеть версию или 3.10, или 3.11, или 3.12 +2. Скачайте готовый пакет Insightface [для версии 3.10](https://github.com/Gourieff/sd-webui-reactor/raw/main/example/insightface-0.7.3-cp310-cp310-win_amd64.whl) или [для 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (если на предыдущем шаге вы увидели 3.11) или [для 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (если на предыдущем шаге вы увидели 3.12) и сохраните его в корневую директорию stable-diffusion-webui (A1111 или SD.Next) - туда, где лежит файл "webui-user.bat" -ИЛИ- в корневую директорию ComfyUI, если вы используете ComfyUI Portable +3. Из корневой директории запустите: + - (SD WebUI) CMD и `.\venv\Scripts\activate` + - (ComfyUI Portable) CMD +4. Обновите PIP: + - (SD WebUI) `python -m pip install -U pip` + - (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip` +5. Затем установите Insightface: + - (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12) + - (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12) +6. Готово! + +### **II. "AttributeError: 'NoneType' object has no attribute 'get'"** + +Эта ошибка появляется, если что-то не так с файлом модели `inswapper_128.onnx` + +Скачайте вручную по ссылке [отсюда](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) +и сохраните в директорию `ComfyUI\models\insightface`, заменив имеющийся файл + +### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"** + +Это означает, что поменялось обозначение входных точек (input points) всвязи с последним обновлением
+Удалите из вашего рабочего пространства имеющийся ReActor Node и добавьте его снова + +### **IV. ControlNet Aux Node IMPORT failed - при использовании совместно с нодом ReActor** + +1. Закройте или остановите ComfyUI сервер, если он запущен +2. Перейдите в корневую папку ComfyUI, откройте консоль CMD и выполните следующее: + - `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless` + - `python_embeded\python.exe -m pip install opencv-python==4.7.0.72` +3. Готово! + +reactor+controlnet + +### **V. "ModuleNotFoundError: No module named 'basicsr'" или "subprocess-exited-with-error" при установке пакета future-0.18.3** + +- Скачайте https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl
+- Скопируйте файл в корневую папку ComfyUI и выполните в консоли: + + python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl + +- Затем: + + python_embeded\python.exe -m pip install basicsr + +### **VI. "fatal: fetch-pack: invalid index-pack output" при исполнении команды `git clone`"** + +Попробуйте клонировать репозиторий с параметром `--depth=1` (только последний коммит): + + git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor + +Затем вытяните оставшееся (если требуется): + + git fetch --unshallow + +
+ +## Обновление + +Положите .bat или .sh скрипт из [данного репозитория](https://github.com/Gourieff/sd-webui-extensions-updater) в папку `ComfyUI\custom_nodes` и запустите, когда желаете обновить ComfyUI и Ноды + + + +## Ответственность + +Это программное обеспечение призвано стать продуктивным вкладом в быстрорастущую медиаиндустрию на основе генеративных сетей и искусственного интеллекта. Данное ПО поможет художникам в решении таких задач, как анимация собственного персонажа или использование персонажа в качестве модели для одежды и т.д. + +Разработчики этого программного обеспечения осведомлены о возможных неэтичных применениях и обязуются принять против этого превентивные меры. Мы продолжим развивать этот проект в позитивном направлении, придерживаясь закона и этики. + +Подразумевается, что пользователи этого программного обеспечения будут использовать его ответственно, соблюдая локальное законодательство. Если используется лицо реального человека, пользователь обязан получить согласие заинтересованного лица и четко указать, что это дипфейк при размещении контента в Интернете. **Разработчики и Со-авторы данного программного обеспечения не несут ответственности за действия конечных пользователей.** + +Используя данное расширение, вы соглашаетесь не создавать материалы, которые: +- нарушают какие-либо действующие законы тех или иных государств или международных организаций; +- причиняют какой-либо вред человеку или лицам; +- пропагандируют любую информацию (как общедоступную, так и личную) или изображения (как общедоступные, так и личные), которые могут быть направлены на причинение вреда; +- используются для распространения дезинформации; +- нацелены на уязвимые группы людей. + +Данное программное обеспечение использует предварительно обученные модели `buffalo_l` и `inswapper_128.onnx`, представленные разработчиками [InsightFace](https://github.com/deepinsight/insightface/). Эти модели распространяются при следующих условиях: + +[Перевод из текста лицензии insighface](https://github.com/deepinsight/insightface/tree/master/python-package): Предварительно обученные модели InsightFace доступны только для некоммерческих исследовательских целей. Сюда входят как модели с автоматической загрузкой, так и модели, загруженные вручную. + +Пользователи данного программного обеспечения должны строго соблюдать данные условия использования. Разработчики и Со-авторы данного программного продукта не несут ответственности за неправильное использование предварительно обученных моделей InsightFace. + +Обратите внимание: если вы собираетесь использовать это программное обеспечение в каких-либо коммерческих целях, вам необходимо будет обучить свои собственные модели или найти модели, которые можно использовать в коммерческих целях. 
+ +### Хэш файлов моделей + +#### Безопасные для использования модели имеют следующий хэш: + +inswapper_128.onnx +``` +MD5:a3a155b90354160350efd66fed6b3d80 +SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af +``` + +1k3d68.onnx + +``` +MD5:6fb94fcdb0055e3638bf9158e6a108f4 +SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +``` + +2d106det.onnx + +``` +MD5:a3613ef9eb3662b4ef88eb90db1fcf26 +SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +``` + +det_10g.onnx + +``` +MD5:4c10eef5c9e168357a16fdd580fa8371 +SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +``` + +genderage.onnx + +``` +MD5:81c77ba87ab38163b0dec6b26f8e2af2 +SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +``` + +w600k_r50.onnx + +``` +MD5:80248d427976241cbd1343889ed132b3 +SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43 +``` + +**Пожалуйста, сравните хэш, если вы скачиваете данные модели из непроверенных источников** + + + +## Благодарности и авторы компонентов + +
+ Нажмите, чтобы посмотреть + +
+ +|файл|источник|лицензия| +|----|--------|--------| +|[buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) | +| [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) | +| [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) | +| [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) | ![license](https://img.shields.io/badge/license-non_commercial-red) | + +[BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup)
+[facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao)
+ +[@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - оригинальное приложение Roop
+[@ssitu](https://github.com/ssitu) - первая версия расширения с поддержкой ComfyUI [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop) + +
+ + + +### Обратите внимание! + +**Если у вас возникли какие-либо ошибки при очередном использовании Нода ReActor - не торопитесь открывать Issue, для начала попробуйте удалить текущий Нод из вашего рабочего пространства и добавить его снова** + +**ReActor Node периодически получает обновления, появляются новые функции, из-за чего имеющийся Нод может работать с ошибками или не работать вовсе** diff --git a/custom_nodes/ComfyUI-ReActor/__init__.py b/custom_nodes/ComfyUI-ReActor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eac9effa0d4d0b04f4a96f004cb295be55f99ba1 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/__init__.py @@ -0,0 +1,39 @@ +import sys +import os + +repo_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.insert(0, repo_dir) +original_modules = sys.modules.copy() + +# Place aside existing modules if using a1111 web ui +modules_used = [ + "modules", + "modules.images", + "modules.processing", + "modules.scripts_postprocessing", + "modules.scripts", + "modules.shared", +] +original_webui_modules = {} +for module in modules_used: + if module in sys.modules: + original_webui_modules[module] = sys.modules.pop(module) + +# Proceed with node setup +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] + +# Clean up imports +# Remove repo directory from path +sys.path.remove(repo_dir) +# Remove any new modules +modules_to_remove = [] +for module in sys.modules: + if module not in original_modules and not module.startswith("google.protobuf") and not module.startswith("onnx") and not module.startswith("cv2"): + modules_to_remove.append(module) +for module in modules_to_remove: + del sys.modules[module] + +# Restore original modules +sys.modules.update(original_webui_modules) diff --git a/custom_nodes/ComfyUI-ReActor/install.bat b/custom_nodes/ComfyUI-ReActor/install.bat new file mode 100644 index 0000000000000000000000000000000000000000..290195e8e64f4bfd161828f5e8d224665296feea --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/install.bat @@ -0,0 +1,37 @@ +@echo off +setlocal enabledelayedexpansion + +:: Try to use embedded python first +if exist ..\..\..\python_embeded\python.exe ( + :: Use the embedded python + set PYTHON=..\..\..\python_embeded\python.exe +) else ( + :: Embedded python not found, check for python in the PATH + for /f "tokens=* USEBACKQ" %%F in (`python --version 2^>^&1`) do ( + set PYTHON_VERSION=%%F + ) + if errorlevel 1 ( + echo I couldn't find an embedded version of Python, nor one in the Windows PATH. Please install manually. + pause + exit /b 1 + ) else ( + :: Use python from the PATH (if it's the right version and the user agrees) + echo I couldn't find an embedded version of Python, but I did find !PYTHON_VERSION! in your Windows PATH. + echo Would you like to proceed with the install using that version? (Y/N^) + set /p USE_PYTHON= + if /i "!USE_PYTHON!"=="Y" ( + set PYTHON=python + ) else ( + echo Okay. Please install manually. + pause + exit /b 1 + ) + ) +) + +:: Install the package +echo Installing... +%PYTHON% install.py +echo Done^! 
+ +@pause \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/install.py b/custom_nodes/ComfyUI-ReActor/install.py new file mode 100644 index 0000000000000000000000000000000000000000..14cd972f0e694a56dff3997fe142d4193d4433aa --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/install.py @@ -0,0 +1,104 @@ +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) + +import subprocess +import os, sys +try: + from pkg_resources import get_distribution as distributions +except: + from importlib_metadata import distributions +from tqdm import tqdm +import urllib.request +from packaging import version as pv +try: + from folder_paths import models_dir +except: + from pathlib import Path + models_dir = os.path.join(Path(__file__).parents[2], "models") + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") + +model_url = "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/inswapper_128.onnx" +model_name = os.path.basename(model_url) +models_dir_path = os.path.join(models_dir, "insightface") +model_path = os.path.join(models_dir_path, model_name) + +def run_pip(*args): + subprocess.run([sys.executable, "-m", "pip", "install", "--no-warn-script-location", *args]) + +def is_installed ( + package: str, version: str = None, strict: bool = True +): + has_package = None + try: + has_package = distributions(package) + if has_package is not None: + if version is not None: + installed_version = has_package.version + if (installed_version != version and strict == True) or (pv.parse(installed_version) < pv.parse(version) and strict == False): + return False + else: + return True + else: + return True + else: + return False + except Exception as e: + print(f"Status: {e}") + return False + +def download(url, path, name): + request = urllib.request.urlopen(url) + total = int(request.headers.get('Content-Length', 0)) + with tqdm(total=total, desc=f'[ReActor] Downloading {name} to {path}', unit='B', unit_scale=True, unit_divisor=1024) as progress: + urllib.request.urlretrieve(url, path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) + +if not os.path.exists(models_dir_path): + os.makedirs(models_dir_path) + +if not os.path.exists(model_path): + download(model_url, model_path, model_name) + +with open(req_file) as file: + try: + ort = "onnxruntime-gpu" + import torch + cuda_version = None + if torch.cuda.is_available(): + cuda_version = torch.version.cuda + print(f"CUDA {cuda_version}") + elif torch.backends.mps.is_available() or hasattr(torch,'dml') or hasattr(torch,'privateuseone'): + ort = "onnxruntime" + if cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ <= "2.2.0": # CU12.x and torch<=2.2.0 + print(f"Torch: {torch.torch_version.__version__}") + if not is_installed(ort,"1.17.0",False): + run_pip(ort,"--extra-index-url", "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/") + elif cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ >= "2.4.0" : # CU12.x and latest torch + print(f"Torch: {torch.torch_version.__version__}") + if not is_installed(ort,"1.20.1",False): # latest ort-gpu + run_pip(ort,"-U") + elif not is_installed(ort,"1.16.1",False): + run_pip(ort, "-U") + except Exception as e: + print(e) + print(f"Warning: Failed to install {ort}, ReActor will not work.") + raise e + strict = True + for package 
in file: + package_version = None + try: + package = package.strip() + if "==" in package: + package_version = package.split('==')[1] + elif ">=" in package: + package_version = package.split('>=')[1] + strict = False + if not is_installed(package,package_version,strict): + run_pip(package) + except Exception as e: + print(e) + print(f"Warning: Failed to install {package}, ReActor will not work.") + raise e +print("Ok") diff --git a/custom_nodes/ComfyUI-ReActor/modules/__init__.py b/custom_nodes/ComfyUI-ReActor/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/modules/images.py b/custom_nodes/ComfyUI-ReActor/modules/images.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/modules/processing.py b/custom_nodes/ComfyUI-ReActor/modules/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..d52541824a5d566516b8f8d1a24ad626a090f6cf --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/modules/processing.py @@ -0,0 +1,13 @@ +class StableDiffusionProcessing: + + def __init__(self, init_imgs): + self.init_images = init_imgs + self.width = init_imgs[0].width + self.height = init_imgs[0].height + self.extra_generation_params = {} + + +class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): + + def __init__(self, init_img): + super().__init__(init_img) diff --git a/custom_nodes/ComfyUI-ReActor/modules/scripts.py b/custom_nodes/ComfyUI-ReActor/modules/scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..5eae950c9db521fbdadf06ceeb7cd40de2bb18cc --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/modules/scripts.py @@ -0,0 +1,13 @@ +import os + + +class Script: + pass + + +def basedir(): + return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class PostprocessImageArgs: + pass diff --git a/custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py b/custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/modules/shared.py b/custom_nodes/ComfyUI-ReActor/modules/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..01263864aa9b0d94462b1466d44f61d295a9d9e4 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/modules/shared.py @@ -0,0 +1,19 @@ +class Options: + img2img_background_color = "#ffffff" # Set to white for now + + +class State: + interrupted = False + + def begin(self): + pass + + def end(self): + pass + + +opts = Options() +state = State() +cmd_opts = None +sd_upscalers = [] +face_restorers = [] diff --git a/custom_nodes/ComfyUI-ReActor/nodes.py b/custom_nodes/ComfyUI-ReActor/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b09eda0c6cf16c824c04ec8a4f6a3ddda55c36 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/nodes.py @@ -0,0 +1,1364 @@ +import os, glob, sys +import logging + +import torch +import torch.nn.functional as torchfn +from torchvision.transforms.functional import normalize +from torchvision.ops import masks_to_boxes + +import numpy as np +import cv2 +import math +from typing import List +from PIL import Image +from scipy import stats +from insightface.app.common import Face +from segment_anything import sam_model_registry + +from modules.processing import 
StableDiffusionProcessingImg2Img +from modules.shared import state +# from comfy_extras.chainner_models import model_loading +import comfy.model_management as model_management +import comfy.utils +import folder_paths + +import scripts.reactor_version +from r_chainner import model_loading +from scripts.reactor_faceswap import ( + FaceSwapScript, + get_models, + get_current_faces_model, + analyze_faces, + half_det_size, + providers +) +from scripts.reactor_swapper import ( + unload_all_models, +) +from scripts.reactor_logger import logger +from reactor_utils import ( + batch_tensor_to_pil, + batched_pil_to_tensor, + tensor_to_pil, + img2tensor, + tensor2img, + save_face_model, + load_face_model, + download, + set_ort_session, + prepare_cropped_face, + normalize_cropped_face, + add_folder_path_and_extensions, + rgba2rgb_tensor +) +from reactor_patcher import apply_patch +from r_facelib.utils.face_restoration_helper import FaceRestoreHelper +from r_basicsr.utils.registry import ARCH_REGISTRY +import scripts.r_archs.codeformer_arch +import scripts.r_masking.subcore as subcore +import scripts.r_masking.core as core +import scripts.r_masking.segs as masking_segs + +import scripts.reactor_sfw as sfw + + +models_dir = folder_paths.models_dir +REACTOR_MODELS_PATH = os.path.join(models_dir, "reactor") +FACE_MODELS_PATH = os.path.join(REACTOR_MODELS_PATH, "faces") +NSFWDET_MODEL_PATH = os.path.join(models_dir, "nsfw_detector","vit-base-nsfw-detector") + +if not os.path.exists(REACTOR_MODELS_PATH): + os.makedirs(REACTOR_MODELS_PATH) + if not os.path.exists(FACE_MODELS_PATH): + os.makedirs(FACE_MODELS_PATH) + +dir_facerestore_models = os.path.join(models_dir, "facerestore_models") +os.makedirs(dir_facerestore_models, exist_ok=True) +folder_paths.folder_names_and_paths["facerestore_models"] = ([dir_facerestore_models], folder_paths.supported_pt_extensions) + +BLENDED_FACE_MODEL = None +FACE_SIZE: int = 512 +FACE_HELPER = None + +if "ultralytics" not in folder_paths.folder_names_and_paths: + add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(models_dir, "ultralytics", "bbox")], folder_paths.supported_pt_extensions) + add_folder_path_and_extensions("ultralytics_segm", [os.path.join(models_dir, "ultralytics", "segm")], folder_paths.supported_pt_extensions) + add_folder_path_and_extensions("ultralytics", [os.path.join(models_dir, "ultralytics")], folder_paths.supported_pt_extensions) +if "sams" not in folder_paths.folder_names_and_paths: + add_folder_path_and_extensions("sams", [os.path.join(models_dir, "sams")], folder_paths.supported_pt_extensions) + +def get_facemodels(): + models_path = os.path.join(FACE_MODELS_PATH, "*") + models = glob.glob(models_path) + models = [x for x in models if x.endswith(".safetensors")] + return models + +def get_restorers(): + models_path = os.path.join(models_dir, "facerestore_models/*") + models = glob.glob(models_path) + models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))] + if len(models) == 0: + fr_urls = [ + "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.3.pth", + "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.4.pth", + "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/codeformer-v0.1.0.pth", + "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GPEN-BFR-512.onnx", + ] + for model_url in fr_urls: + model_name = os.path.basename(model_url) + model_path = 
os.path.join(dir_facerestore_models, model_name) + download(model_url, model_path, model_name) + models = glob.glob(models_path) + models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))] + return models + +def get_model_names(get_models): + models = get_models() + names = [] + for x in models: + names.append(os.path.basename(x)) + names.sort(key=str.lower) + names.insert(0, "none") + return names + +def model_names(): + models = get_models() + return {os.path.basename(x): x for x in models} + + +class reactor: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}), + "input_image": ("IMAGE",), + "swap_model": (list(model_names().keys()),), + "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],), + "face_restore_model": (get_model_names(get_restorers),), + "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}), + "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}), + "detect_gender_input": (["no","female","male"], {"default": "no"}), + "detect_gender_source": (["no","female","male"], {"default": "no"}), + "input_faces_index": ("STRING", {"default": "0"}), + "source_faces_index": ("STRING", {"default": "0"}), + "console_log_level": ([0, 1, 2], {"default": 1}), + }, + "optional": { + "source_image": ("IMAGE",), + "face_model": ("FACE_MODEL",), + "face_boost": ("FACE_BOOST",), + }, + "hidden": {"faces_order": "FACES_ORDER"}, + } + + RETURN_TYPES = ("IMAGE","FACE_MODEL") + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def __init__(self): + # self.face_helper = None + self.faces_order = ["large-small", "large-small"] + # self.face_size = FACE_SIZE + self.face_boost_enabled = False + self.restore = True + self.boost_model = None + self.interpolation = "Bicubic" + self.boost_model_visibility = 1 + self.boost_cf_weight = 0.5 + + def restore_face( + self, + input_image, + face_restore_model, + face_restore_visibility, + codeformer_weight, + facedetection, + ): + + result = input_image + + if face_restore_model != "none" and not model_management.processing_interrupted(): + + global FACE_SIZE, FACE_HELPER + + self.face_helper = FACE_HELPER + + faceSize = 512 + if "1024" in face_restore_model.lower(): + faceSize = 1024 + elif "2048" in face_restore_model.lower(): + faceSize = 2048 + + logger.status(f"Restoring with {face_restore_model} | Face Size is set to {faceSize}") + + model_path = folder_paths.get_full_path("facerestore_models", face_restore_model) + + device = model_management.get_torch_device() + + if "codeformer" in face_restore_model.lower(): + + codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device) + checkpoint = torch.load(model_path)["params_ema"] + codeformer_net.load_state_dict(checkpoint) + facerestore_model = codeformer_net.eval() + + elif ".onnx" in face_restore_model: + + ort_session = set_ort_session(model_path, providers=providers) + ort_session_inputs = {} + facerestore_model = ort_session + + else: + + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + facerestore_model = model_loading.load_state_dict(sd).eval() + facerestore_model.to(device) + + if faceSize != FACE_SIZE or self.face_helper is None: + self.face_helper = FaceRestoreHelper(1, face_size=faceSize, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, 
device=device) + FACE_SIZE = faceSize + FACE_HELPER = self.face_helper + + image_np = 255. * result.numpy() + + total_images = image_np.shape[0] + + out_images = [] + + for i in range(total_images): + + if total_images > 1: + logger.status(f"Restoring {i+1}") + + cur_image_np = image_np[i,:, :, ::-1] + + original_resolution = cur_image_np.shape[0:2] + + if facerestore_model is None or self.face_helper is None: + return result + + self.face_helper.clean_all() + self.face_helper.read_image(cur_image_np) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + + restored_face = None + + for idx, cropped_face in enumerate(self.face_helper.cropped_faces): + + # if ".pth" in face_restore_model: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + + with torch.no_grad(): + + if ".onnx" in face_restore_model: # ONNX models + + for ort_session_input in ort_session.get_inputs(): + if ort_session_input.name == "input": + cropped_face_prep = prepare_cropped_face(cropped_face) + ort_session_inputs[ort_session_input.name] = cropped_face_prep + if ort_session_input.name == "weight": + weight = np.array([ 1 ], dtype = np.double) + ort_session_inputs[ort_session_input.name] = weight + + output = ort_session.run(None, ort_session_inputs)[0][0] + restored_face = normalize_cropped_face(output) + + else: # PTH models + + output = facerestore_model(cropped_face_t, w=codeformer_weight)[0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + + del output + torch.cuda.empty_cache() + + except Exception as error: + + print(f"\tFailed inference: {error}", file=sys.stderr) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + if face_restore_visibility < 1: + restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility + + restored_face = restored_face.astype("uint8") + self.face_helper.add_restored_face(restored_face) + + self.face_helper.get_inverse_affine(None) + + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] + + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_AREA) + + self.face_helper.clean_all() + + # out_images[i] = restored_img + out_images.append(restored_img) + + if state.interrupted or model_management.processing_interrupted(): + logger.status("Interrupted by User") + return input_image + + restored_img_np = np.array(out_images).astype(np.float32) / 255.0 + restored_img_tensor = torch.from_numpy(restored_img_np) + + result = restored_img_tensor + + return result + + def execute(self, enabled, input_image, swap_model, detect_gender_source, detect_gender_input, source_faces_index, input_faces_index, console_log_level, face_restore_model,face_restore_visibility, codeformer_weight, facedetection, source_image=None, face_model=None, faces_order=None, face_boost=None): + + if face_boost is not None: + self.face_boost_enabled = face_boost["enabled"] + self.boost_model = face_boost["boost_model"] + self.interpolation = face_boost["interpolation"] + 
self.boost_model_visibility = face_boost["visibility"] + self.boost_cf_weight = face_boost["codeformer_weight"] + self.restore = face_boost["restore_with_main_after"] + else: + self.face_boost_enabled = False + + if faces_order is None: + faces_order = self.faces_order + + apply_patch(console_log_level) + + if not enabled: + return (input_image,face_model) + elif source_image is None and face_model is None: + logger.error("Please provide 'source_image' or `face_model`") + return (input_image,face_model) + + if face_model == "none": + face_model = None + + script = FaceSwapScript() + pil_images = batch_tensor_to_pil(input_image) + + # NSFW checker + logger.status("Checking for any unsafe content") + pil_images_sfw = [] + tmp_img = "reactor_tmp.png" + for img in pil_images: + if state.interrupted or model_management.processing_interrupted(): + logger.status("Interrupted by User") + break + img.save(tmp_img) + if not sfw.nsfw_image(tmp_img, NSFWDET_MODEL_PATH): + pil_images_sfw.append(img) + if os.path.exists(tmp_img): + os.remove(tmp_img) + pil_images = pil_images_sfw + # # # + + if len(pil_images) > 0: + + if source_image is not None: + source = tensor_to_pil(source_image) + else: + source = None + p = StableDiffusionProcessingImg2Img(pil_images) + script.process( + p=p, + img=source, + enable=True, + source_faces_index=source_faces_index, + faces_index=input_faces_index, + model=swap_model, + swap_in_source=True, + swap_in_generated=True, + gender_source=detect_gender_source, + gender_target=detect_gender_input, + face_model=face_model, + faces_order=faces_order, + # face boost: + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.boost_model, + face_restore_visibility=self.boost_model_visibility, + codeformer_weight=self.boost_cf_weight, + interpolation=self.interpolation, + ) + result = batched_pil_to_tensor(p.init_images) + + if face_model is None: + current_face_model = get_current_faces_model() + face_model_to_provide = current_face_model[0] if (current_face_model is not None and len(current_face_model) > 0) else face_model + else: + face_model_to_provide = face_model + + if self.restore or not self.face_boost_enabled: + result = reactor.restore_face(self,result,face_restore_model,face_restore_visibility,codeformer_weight,facedetection) + + else: + image_black = Image.new("RGB", (512, 512)) + result = batched_pil_to_tensor([image_black]) + face_model_to_provide = None + + return (result,face_model_to_provide) + + +class ReActorPlusOpt: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}), + "input_image": ("IMAGE",), + "swap_model": (list(model_names().keys()),), + "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],), + "face_restore_model": (get_model_names(get_restorers),), + "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}), + "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}), + }, + "optional": { + "source_image": ("IMAGE",), + "face_model": ("FACE_MODEL",), + "options": ("OPTIONS",), + "face_boost": ("FACE_BOOST",), + } + } + + RETURN_TYPES = ("IMAGE","FACE_MODEL") + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def __init__(self): + # self.face_helper = None + self.faces_order = ["large-small", "large-small"] + self.detect_gender_input = "no" + self.detect_gender_source = "no" + self.input_faces_index = "0" + self.source_faces_index = "0" + 
self.console_log_level = 1 + # self.face_size = 512 + self.face_boost_enabled = False + self.restore = True + self.boost_model = None + self.interpolation = "Bicubic" + self.boost_model_visibility = 1 + self.boost_cf_weight = 0.5 + + def execute(self, enabled, input_image, swap_model, facedetection, face_restore_model, face_restore_visibility, codeformer_weight, source_image=None, face_model=None, options=None, face_boost=None): + + if options is not None: + self.faces_order = [options["input_faces_order"], options["source_faces_order"]] + self.console_log_level = options["console_log_level"] + self.detect_gender_input = options["detect_gender_input"] + self.detect_gender_source = options["detect_gender_source"] + self.input_faces_index = options["input_faces_index"] + self.source_faces_index = options["source_faces_index"] + + if face_boost is not None: + self.face_boost_enabled = face_boost["enabled"] + self.restore = face_boost["restore_with_main_after"] + else: + self.face_boost_enabled = False + + result = reactor.execute( + self,enabled,input_image,swap_model,self.detect_gender_source,self.detect_gender_input,self.source_faces_index,self.input_faces_index,self.console_log_level,face_restore_model,face_restore_visibility,codeformer_weight,facedetection,source_image,face_model,self.faces_order, face_boost=face_boost + ) + + return result + + +class LoadFaceModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "face_model": (get_model_names(get_facemodels),), + } + } + + RETURN_TYPES = ("FACE_MODEL",) + FUNCTION = "load_model" + CATEGORY = "🌌 ReActor" + + def load_model(self, face_model): + self.face_model = face_model + self.face_models_path = FACE_MODELS_PATH + if self.face_model != "none": + face_model_path = os.path.join(self.face_models_path, self.face_model) + out = load_face_model(face_model_path) + else: + out = None + return (out, ) + + +class ReActorWeight: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_image": ("IMAGE",), + "faceswap_weight": (["0%", "12.5%", "25%", "37.5%", "50%", "62.5%", "75%", "87.5%", "100%"], {"default": "50%"}), + }, + "optional": { + "source_image": ("IMAGE",), + "face_model": ("FACE_MODEL",), + } + } + + RETURN_TYPES = ("IMAGE","FACE_MODEL") + RETURN_NAMES = ("INPUT_IMAGE","FACE_MODEL") + FUNCTION = "set_weight" + + OUTPUT_NODE = True + + CATEGORY = "🌌 ReActor" + + def set_weight(self, input_image, faceswap_weight, face_model=None, source_image=None): + + if input_image is None: + logger.error("Please provide `input_image`") + return (input_image,None) + + if source_image is None and face_model is None: + logger.error("Please provide `source_image` or `face_model`") + return (input_image,None) + + weight = float(faceswap_weight.split("%")[0]) + + images = [] + faces = [] if face_model is None else [face_model] + embeddings = [] if face_model is None else [face_model.embedding] + + if weight == 0: + images = [input_image] + faces = [] + embeddings = [] + elif weight == 100: + if face_model is None: + images = [source_image] + else: + if weight > 50: + images = [input_image] + count = round(100/(100-weight)) + else: + if face_model is None: + images = [source_image] + count = round(100/(weight)) + for i in range(count-1): + if weight > 50: + if face_model is None: + images.append(source_image) + else: + faces.append(face_model) + embeddings.append(face_model.embedding) + else: + images.append(input_image) + + images_list: List[Image.Image] = [] + + apply_patch(1) + + if len(images) > 0: + + for image in 
images: + img = tensor_to_pil(image) + images_list.append(img) + + for image in images_list: + face = BuildFaceModel.build_face_model(self,image) + if isinstance(face, str): + continue + faces.append(face) + embeddings.append(face.embedding) + + if len(faces) > 0: + blended_embedding = np.mean(embeddings, axis=0) + blended_face = Face( + bbox=faces[0].bbox, + kps=faces[0].kps, + det_score=faces[0].det_score, + landmark_3d_68=faces[0].landmark_3d_68, + pose=faces[0].pose, + landmark_2d_106=faces[0].landmark_2d_106, + embedding=blended_embedding, + gender=faces[0].gender, + age=faces[0].age + ) + if blended_face is None: + no_face_msg = "Something went wrong, please try another set of images" + logger.error(no_face_msg) + + return (input_image,blended_face) + + +class BuildFaceModel: + def __init__(self): + self.output_dir = FACE_MODELS_PATH + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}), + "send_only": ("BOOLEAN", {"default": False, "label_off": "NO", "label_on": "YES"}), + "face_model_name": ("STRING", {"default": "default"}), + "compute_method": (["Mean", "Median", "Mode"], {"default": "Mean"}), + }, + "optional": { + "images": ("IMAGE",), + "face_models": ("FACE_MODEL",), + } + } + + RETURN_TYPES = ("FACE_MODEL",) + FUNCTION = "blend_faces" + + OUTPUT_NODE = True + + CATEGORY = "🌌 ReActor" + + def build_face_model(self, image: Image.Image, det_size=(640, 640)): + logging.StreamHandler.terminator = "\n" + if image is None: + error_msg = "Please load an Image" + logger.error(error_msg) + return error_msg + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + face_model = analyze_faces(image, det_size) + + if len(face_model) == 0: + print("") + det_size_half = half_det_size(det_size) + face_model = analyze_faces(image, det_size_half) + if face_model is not None and len(face_model) > 0: + print("...........................................................", end=" ") + + if face_model is not None and len(face_model) > 0: + return face_model[0] + else: + no_face_msg = "No face found, please try another image" + # logger.error(no_face_msg) + return no_face_msg + + def blend_faces(self, save_mode, send_only, face_model_name, compute_method, images=None, face_models=None): + global BLENDED_FACE_MODEL + blended_face: Face = BLENDED_FACE_MODEL + + if send_only and blended_face is None: + send_only = False + + if (images is not None or face_models is not None) and not send_only: + + faces = [] + embeddings = [] + + apply_patch(1) + + if images is not None: + images_list: List[Image.Image] = batch_tensor_to_pil(images) + + n = len(images_list) + + for i,image in enumerate(images_list): + logging.StreamHandler.terminator = " " + logger.status(f"Building Face Model {i+1} of {n}...") + face = self.build_face_model(image) + if isinstance(face, str): + logger.error(f"No faces found in image {i+1}, skipping") + continue + else: + print(f"{int(((i+1)/n)*100)}%") + faces.append(face) + embeddings.append(face.embedding) + + elif face_models is not None: + + n = len(face_models) + + for i,face_model in enumerate(face_models): + logging.StreamHandler.terminator = " " + logger.status(f"Extracting Face Model {i+1} of {n}...") + face = face_model + if isinstance(face, str): + logger.error(f"No faces found for face_model {i+1}, skipping") + continue + else: + print(f"{int(((i+1)/n)*100)}%") + faces.append(face) + embeddings.append(face.embedding) + + logging.StreamHandler.terminator = "\n" + if len(faces) > 
0: + # compute_method_name = "Mean" if compute_method == 0 else "Median" if compute_method == 1 else "Mode" + logger.status(f"Blending with Compute Method '{compute_method}'...") + blended_embedding = np.mean(embeddings, axis=0) if compute_method == "Mean" else np.median(embeddings, axis=0) if compute_method == "Median" else stats.mode(embeddings, axis=0)[0].astype(np.float32) + blended_face = Face( + bbox=faces[0].bbox, + kps=faces[0].kps, + det_score=faces[0].det_score, + landmark_3d_68=faces[0].landmark_3d_68, + pose=faces[0].pose, + landmark_2d_106=faces[0].landmark_2d_106, + embedding=blended_embedding, + gender=faces[0].gender, + age=faces[0].age + ) + if blended_face is not None: + BLENDED_FACE_MODEL = blended_face + if save_mode: + face_model_path = os.path.join(FACE_MODELS_PATH, face_model_name + ".safetensors") + save_face_model(blended_face,face_model_path) + # done_msg = f"Face model has been saved to '{face_model_path}'" + # logger.status(done_msg) + logger.status("--Done!--") + # return (blended_face,) + else: + no_face_msg = "Something went wrong, please try another set of images" + logger.error(no_face_msg) + # return (blended_face,) + # logger.status("--Done!--") + if images is None and face_models is None: + logger.error("Please provide `images` or `face_models`") + return (blended_face,) + + +class SaveFaceModel: + def __init__(self): + self.output_dir = FACE_MODELS_PATH + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}), + "face_model_name": ("STRING", {"default": "default"}), + "select_face_index": ("INT", {"default": 0, "min": 0}), + }, + "optional": { + "image": ("IMAGE",), + "face_model": ("FACE_MODEL",), + } + } + + RETURN_TYPES = () + FUNCTION = "save_model" + + OUTPUT_NODE = True + + CATEGORY = "🌌 ReActor" + + def save_model(self, save_mode, face_model_name, select_face_index, image=None, face_model=None, det_size=(640, 640)): + if save_mode and image is not None: + source = tensor_to_pil(image) + source = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR) + apply_patch(1) + logger.status("Building Face Model...") + face_model_raw = analyze_faces(source, det_size) + if len(face_model_raw) == 0: + det_size_half = half_det_size(det_size) + face_model_raw = analyze_faces(source, det_size_half) + try: + face_model = face_model_raw[select_face_index] + except: + logger.error("No face(s) found") + return face_model_name + logger.status("--Done!--") + if save_mode and (face_model != "none" or face_model is not None): + face_model_path = os.path.join(self.output_dir, face_model_name + ".safetensors") + save_face_model(face_model,face_model_path) + if image is None and face_model is None: + logger.error("Please provide `face_model` or `image`") + return face_model_name + + +class RestoreFace: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],), + "model": (get_model_names(get_restorers),), + "visibility": ("FLOAT", {"default": 1, "min": 0.0, "max": 1, "step": 0.05}), + "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + # def __init__(self): + # self.face_helper = None + # self.face_size = 512 + + def execute(self, image, model, visibility, codeformer_weight, facedetection): + result = 
reactor.restore_face(self,image,model,visibility,codeformer_weight,facedetection) + return (result,) + + +class MaskHelper: + def __init__(self): + # self.threshold = 0.5 + # self.dilation = 10 + # self.crop_factor = 3.0 + # self.drop_size = 1 + self.labels = "all" + self.detailer_hook = None + self.device_mode = "AUTO" + self.detection_hint = "center-1" + # self.sam_dilation = 0 + # self.sam_threshold = 0.93 + # self.bbox_expansion = 0 + # self.mask_hint_threshold = 0.7 + # self.mask_hint_use_negative = "False" + # self.force_resize_width = 0 + # self.force_resize_height = 0 + # self.resize_behavior = "source_size" + + @classmethod + def INPUT_TYPES(s): + bboxs = ["bbox/"+x for x in folder_paths.get_filename_list("ultralytics_bbox")] + segms = ["segm/"+x for x in folder_paths.get_filename_list("ultralytics_segm")] + sam_models = [x for x in folder_paths.get_filename_list("sams") if 'hq' not in x] + return { + "required": { + "image": ("IMAGE",), + "swapped_image": ("IMAGE",), + "bbox_model_name": (bboxs + segms, ), + "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "bbox_dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}), + "bbox_crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}), + "bbox_drop_size": ("INT", {"min": 1, "max": 8192, "step": 1, "default": 10}), + "sam_model_name": (sam_models, ), + "sam_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}), + "sam_threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}), + "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), + "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}), + "mask_hint_use_negative": (["False", "Small", "Outter"], ), + "morphology_operation": (["dilate", "erode", "open", "close"],), + "morphology_distance": ("INT", {"default": 0, "min": 0, "max": 128, "step": 1}), + "blur_radius": ("INT", {"default": 9, "min": 0, "max": 48, "step": 1}), + "sigma_factor": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 3., "step": 0.01}), + }, + "optional": { + "mask_optional": ("MASK",), + } + } + + RETURN_TYPES = ("IMAGE","MASK","IMAGE","IMAGE") + RETURN_NAMES = ("IMAGE","MASK","MASK_PREVIEW","SWAPPED_FACE") + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self, image, swapped_image, bbox_model_name, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, sam_model_name, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative, morphology_operation, morphology_distance, blur_radius, sigma_factor, mask_optional=None): + + # images = [image[i:i + 1, ...] 
for i in range(image.shape[0])] + + images = image + + if mask_optional is None: + + bbox_model_path = folder_paths.get_full_path("ultralytics", bbox_model_name) + bbox_model = subcore.load_yolo(bbox_model_path) + bbox_detector = subcore.UltraBBoxDetector(bbox_model) + + segs = bbox_detector.detect(images, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, self.detailer_hook) + + if isinstance(self.labels, list): + self.labels = str(self.labels[0]) + + if self.labels is not None and self.labels != '': + self.labels = self.labels.split(',') + if len(self.labels) > 0: + segs, _ = masking_segs.filter(segs, self.labels) + # segs, _ = masking_segs.filter(segs, "all") + + sam_modelname = folder_paths.get_full_path("sams", sam_model_name) + + if 'vit_h' in sam_model_name: + model_kind = 'vit_h' + elif 'vit_l' in sam_model_name: + model_kind = 'vit_l' + else: + model_kind = 'vit_b' + + sam = sam_model_registry[model_kind](checkpoint=sam_modelname) + size = os.path.getsize(sam_modelname) + sam.safe_to = core.SafeToGPU(size) + + device = model_management.get_torch_device() + + sam.safe_to.to_device(sam, device) + + sam.is_auto_mode = self.device_mode == "AUTO" + + combined_mask, _ = core.make_sam_mask_segmented(sam, segs, images, self.detection_hint, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative) + + else: + combined_mask = mask_optional + + # *** MASK TO IMAGE ***: + + mask_image = combined_mask.reshape((-1, 1, combined_mask.shape[-2], combined_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + + # *** MASK MORPH ***: + + mask_image = core.tensor2mask(mask_image) + + if morphology_operation == "dilate": + mask_image = self.dilate(mask_image, morphology_distance) + elif morphology_operation == "erode": + mask_image = self.erode(mask_image, morphology_distance) + elif morphology_operation == "open": + mask_image = self.erode(mask_image, morphology_distance) + mask_image = self.dilate(mask_image, morphology_distance) + elif morphology_operation == "close": + mask_image = self.dilate(mask_image, morphology_distance) + mask_image = self.erode(mask_image, morphology_distance) + + # *** MASK BLUR ***: + + if len(mask_image.size()) == 3: + mask_image = mask_image.unsqueeze(3) + + mask_image = mask_image.permute(0, 3, 1, 2) + kernel_size = blur_radius * 2 + 1 + sigma = sigma_factor * (0.6 * blur_radius - 0.3) + mask_image_final = self.gaussian_blur(mask_image, kernel_size, sigma).permute(0, 2, 3, 1) + if mask_image_final.size()[3] == 1: + mask_image_final = mask_image_final[:, :, :, 0] + + # *** CUT BY MASK ***: + + if len(swapped_image.shape) < 4: + C = 1 + else: + C = swapped_image.shape[3] + + # We operate on RGBA to keep the code clean and then convert back after + swapped_image = core.tensor2rgba(swapped_image) + mask = core.tensor2mask(mask_image_final) + + # Scale the mask to be a matching size if it isn't + B, H, W, _ = swapped_image.shape + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:] + MB, _, _ = mask.shape + + if MB < B: + assert(B % MB == 0) + mask = mask.repeat(B // MB, 1, 1) + + # masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end + is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, H * W]), dim=1).values, 0.) + mask[is_empty,0,0] = 1. + boxes = masks_to_boxes(mask) + mask[is_empty,0,0] = 0. 
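+        # boxes from torchvision's masks_to_boxes are [x_min, y_min, x_max, y_max]; the extents derived below size the face crop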
+ + min_x = boxes[:,0] + min_y = boxes[:,1] + max_x = boxes[:,2] + max_y = boxes[:,3] + + width = max_x - min_x + 1 + height = max_y - min_y + 1 + + use_width = int(torch.max(width).item()) + use_height = int(torch.max(height).item()) + + # if self.force_resize_width > 0: + # use_width = self.force_resize_width + + # if self.force_resize_height > 0: + # use_height = self.force_resize_height + + alpha_mask = torch.ones((B, H, W, 4)) + alpha_mask[:,:,:,3] = mask + + swapped_image = swapped_image * alpha_mask + + cutted_image = torch.zeros((B, use_height, use_width, 4)) + for i in range(0, B): + if not is_empty[i]: + ymin = int(min_y[i].item()) + ymax = int(max_y[i].item()) + xmin = int(min_x[i].item()) + xmax = int(max_x[i].item()) + single = (swapped_image[i, ymin:ymax+1, xmin:xmax+1,:]).unsqueeze(0) + resized = torch.nn.functional.interpolate(single.permute(0, 3, 1, 2), size=(use_height, use_width), mode='bicubic').permute(0, 2, 3, 1) + cutted_image[i] = resized[0] + + # Preserve our type unless we were previously RGB and added non-opaque alpha due to the mask size + if C == 1: + cutted_image = core.tensor2mask(cutted_image) + elif C == 3 and torch.min(cutted_image[:,:,:,3]) == 1: + cutted_image = core.tensor2rgb(cutted_image) + + # *** PASTE BY MASK ***: + + image_base = core.tensor2rgba(images) + image_to_paste = core.tensor2rgba(cutted_image) + mask = core.tensor2mask(mask_image_final) + + # Scale the mask to be a matching size if it isn't + B, H, W, C = image_base.shape + MB = mask.shape[0] + PB = image_to_paste.shape[0] + + if B < PB: + assert(PB % B == 0) + image_base = image_base.repeat(PB // B, 1, 1, 1) + B, H, W, C = image_base.shape + if MB < B: + assert(B % MB == 0) + mask = mask.repeat(B // MB, 1, 1) + elif B < MB: + assert(MB % B == 0) + image_base = image_base.repeat(MB // B, 1, 1, 1) + if PB < B: + assert(B % PB == 0) + image_to_paste = image_to_paste.repeat(B // PB, 1, 1, 1) + + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:] + MB, MH, MW = mask.shape + + # masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end + is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, MH * MW]), dim=1).values, 0.) + mask[is_empty,0,0] = 1. + boxes = masks_to_boxes(mask) + mask[is_empty,0,0] = 0. 
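+        # box centers give the paste position; the per-image loop below pastes the swapped crop back and alpha-blends it with the base image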
+ + min_x = boxes[:,0] + min_y = boxes[:,1] + max_x = boxes[:,2] + max_y = boxes[:,3] + mid_x = (min_x + max_x) / 2 + mid_y = (min_y + max_y) / 2 + + target_width = max_x - min_x + 1 + target_height = max_y - min_y + 1 + + result = image_base.detach().clone() + face_segment = mask_image_final + + for i in range(0, MB): + if is_empty[i]: + continue + else: + image_index = i + source_size = image_to_paste.size() + SB, SH, SW, _ = image_to_paste.shape + + # Figure out the desired size + width = int(target_width[i].item()) + height = int(target_height[i].item()) + # if self.resize_behavior == "keep_ratio_fill": + # target_ratio = width / height + # actual_ratio = SW / SH + # if actual_ratio > target_ratio: + # width = int(height * actual_ratio) + # elif actual_ratio < target_ratio: + # height = int(width / actual_ratio) + # elif self.resize_behavior == "keep_ratio_fit": + # target_ratio = width / height + # actual_ratio = SW / SH + # if actual_ratio > target_ratio: + # height = int(width / actual_ratio) + # elif actual_ratio < target_ratio: + # width = int(height * actual_ratio) + # elif self.resize_behavior == "source_size" or self.resize_behavior == "source_size_unmasked": + + width = SW + height = SH + + # Resize the image we're pasting if needed + resized_image = image_to_paste[i].unsqueeze(0) + # if SH != height or SW != width: + # resized_image = torch.nn.functional.interpolate(resized_image.permute(0, 3, 1, 2), size=(height,width), mode='bicubic').permute(0, 2, 3, 1) + + pasting = torch.ones([H, W, C]) + ymid = float(mid_y[i].item()) + ymin = int(math.floor(ymid - height / 2)) + 1 + ymax = int(math.floor(ymid + height / 2)) + 1 + xmid = float(mid_x[i].item()) + xmin = int(math.floor(xmid - width / 2)) + 1 + xmax = int(math.floor(xmid + width / 2)) + 1 + + _, source_ymax, source_xmax, _ = resized_image.shape + source_ymin, source_xmin = 0, 0 + + if xmin < 0: + source_xmin = abs(xmin) + xmin = 0 + if ymin < 0: + source_ymin = abs(ymin) + ymin = 0 + if xmax > W: + source_xmax -= (xmax - W) + xmax = W + if ymax > H: + source_ymax -= (ymax - H) + ymax = H + + pasting[ymin:ymax, xmin:xmax, :] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, :] + pasting[:, :, 3] = 1. + + pasting_alpha = torch.zeros([H, W]) + pasting_alpha[ymin:ymax, xmin:xmax] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, 3] + + # if self.resize_behavior == "keep_ratio_fill" or self.resize_behavior == "source_size_unmasked": + # # If we explicitly want to fill the area, we are ok with extending outside + # paste_mask = pasting_alpha.unsqueeze(2).repeat(1, 1, 4) + # else: + # paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4) + paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4) + result[image_index] = pasting * paste_mask + result[image_index] * (1. 
- paste_mask) + + face_segment = result + + face_segment[...,3] = mask[i] + + result = rgba2rgb_tensor(result) + + return (result,combined_mask,mask_image_final,face_segment,) + + def gaussian_blur(self, image, kernel_size, sigma): + kernel = torch.Tensor(kernel_size, kernel_size).to(device=image.device) + center = kernel_size // 2 + variance = sigma**2 + for i in range(kernel_size): + for j in range(kernel_size): + x = i - center + y = j - center + kernel[i, j] = math.exp(-(x**2 + y**2)/(2*variance)) + kernel /= kernel.sum() + + # Pad the input tensor + padding = (kernel_size - 1) // 2 + input_pad = torch.nn.functional.pad(image, (padding, padding, padding, padding), mode='reflect') + + # Reshape the padded input tensor for batched convolution + batch_size, num_channels, height, width = image.shape + input_reshaped = input_pad.reshape(batch_size*num_channels, 1, height+padding*2, width+padding*2) + + # Perform batched convolution with the Gaussian kernel + output_reshaped = torch.nn.functional.conv2d(input_reshaped, kernel.unsqueeze(0).unsqueeze(0)) + + # Reshape the output tensor to its original shape + output_tensor = output_reshaped.reshape(batch_size, num_channels, height, width) + + return output_tensor + + def erode(self, image, distance): + return 1. - self.dilate(1. - image, distance) + + def dilate(self, image, distance): + kernel_size = 1 + distance * 2 + # Add the channels dimension + image = image.unsqueeze(1) + out = torchfn.max_pool2d(image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2).squeeze(1) + return out + + +class ImageDublicator: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "count": ("INT", {"default": 1, "min": 0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("IMAGES",) + OUTPUT_IS_LIST = (True,) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self, image, count): + images = [image for i in range(count)] + return (images,) + + +class ImageRGBA2RGB: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self, image): + out = rgba2rgb_tensor(image) + return (out,) + + +class MakeFaceModelBatch: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "face_model1": ("FACE_MODEL",), + }, + "optional": { + "face_model2": ("FACE_MODEL",), + "face_model3": ("FACE_MODEL",), + "face_model4": ("FACE_MODEL",), + "face_model5": ("FACE_MODEL",), + "face_model6": ("FACE_MODEL",), + "face_model7": ("FACE_MODEL",), + "face_model8": ("FACE_MODEL",), + "face_model9": ("FACE_MODEL",), + "face_model10": ("FACE_MODEL",), + }, + } + + RETURN_TYPES = ("FACE_MODEL",) + RETURN_NAMES = ("FACE_MODELS",) + FUNCTION = "execute" + + CATEGORY = "🌌 ReActor" + + def execute(self, **kwargs): + if len(kwargs) > 0: + face_models = [value for value in kwargs.values()] + return (face_models,) + else: + logger.error("Please provide at least 1 `face_model`") + return (None,) + + +class ReActorOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_faces_order": ( + ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"} + ), + "input_faces_index": ("STRING", {"default": "0"}), + "detect_gender_input": (["no","female","male"], {"default": "no"}), + "source_faces_order": ( + ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"} + ), + 
"source_faces_index": ("STRING", {"default": "0"}), + "detect_gender_source": (["no","female","male"], {"default": "no"}), + "console_log_level": ([0, 1, 2], {"default": 1}), + } + } + + RETURN_TYPES = ("OPTIONS",) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self,input_faces_order, input_faces_index, detect_gender_input, source_faces_order, source_faces_index, detect_gender_source, console_log_level): + options: dict = { + "input_faces_order": input_faces_order, + "input_faces_index": input_faces_index, + "detect_gender_input": detect_gender_input, + "source_faces_order": source_faces_order, + "source_faces_index": source_faces_index, + "detect_gender_source": detect_gender_source, + "console_log_level": console_log_level, + } + return (options, ) + + +class ReActorFaceBoost: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}), + "boost_model": (get_model_names(get_restorers),), + "interpolation": (["Nearest","Bilinear","Bicubic","Lanczos"], {"default": "Bicubic"}), + "visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}), + "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}), + "restore_with_main_after": ("BOOLEAN", {"default": False}), + } + } + + RETURN_TYPES = ("FACE_BOOST",) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self,enabled,boost_model,interpolation,visibility,codeformer_weight,restore_with_main_after): + face_boost: dict = { + "enabled": enabled, + "boost_model": boost_model, + "interpolation": interpolation, + "visibility": visibility, + "codeformer_weight": codeformer_weight, + "restore_with_main_after": restore_with_main_after, + } + return (face_boost, ) + +class ReActorUnload: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "trigger": ("IMAGE", ), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + CATEGORY = "🌌 ReActor" + + def execute(self, trigger): + unload_all_models() + return (trigger,) + + +NODE_CLASS_MAPPINGS = { + # --- MAIN NODES --- + "ReActorFaceSwap": reactor, + "ReActorFaceSwapOpt": ReActorPlusOpt, + "ReActorOptions": ReActorOptions, + "ReActorFaceBoost": ReActorFaceBoost, + "ReActorMaskHelper": MaskHelper, + "ReActorSetWeight": ReActorWeight, + # --- Operations with Face Models --- + "ReActorSaveFaceModel": SaveFaceModel, + "ReActorLoadFaceModel": LoadFaceModel, + "ReActorBuildFaceModel": BuildFaceModel, + "ReActorMakeFaceModelBatch": MakeFaceModelBatch, + # --- Additional Nodes --- + "ReActorRestoreFace": RestoreFace, + "ReActorImageDublicator": ImageDublicator, + "ImageRGBA2RGB": ImageRGBA2RGB, + "ReActorUnload": ReActorUnload, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # --- MAIN NODES --- + "ReActorFaceSwap": "ReActor 🌌 Fast Face Swap", + "ReActorFaceSwapOpt": "ReActor 🌌 Fast Face Swap [OPTIONS]", + "ReActorOptions": "ReActor 🌌 Options", + "ReActorFaceBoost": "ReActor 🌌 Face Booster", + "ReActorMaskHelper": "ReActor 🌌 Masking Helper", + "ReActorSetWeight": "ReActor 🌌 Set Face Swap Weight", + # --- Operations with Face Models --- + "ReActorSaveFaceModel": "Save Face Model 🌌 ReActor", + "ReActorLoadFaceModel": "Load Face Model 🌌 ReActor", + "ReActorBuildFaceModel": "Build Blended Face Model 🌌 ReActor", + "ReActorMakeFaceModelBatch": "Make Face Model Batch 🌌 ReActor", + # --- Additional Nodes --- + "ReActorRestoreFace": "Restore Face 🌌 ReActor", + "ReActorImageDublicator": "Image Dublicator (List) 🌌 ReActor", + "ImageRGBA2RGB": "Convert RGBA to 
RGB 🌌 ReActor", + "ReActorUnload": "Unload ReActor Models 🌌 ReActor", +} diff --git a/custom_nodes/ComfyUI-ReActor/pyproject.toml b/custom_nodes/ComfyUI-ReActor/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..384d9327048a1d8a9cfc0f46aa8702cb8b3ba4e3 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "comfyui-reactor" +description = "(SFW-Friendly) The Fast and Simple Face Swap Extension Node for ComfyUI, based on ReActor SD-WebUI Face Swap Extension" +version = "0.6.0-a1" +license = { file = "LICENSE" } +dependencies = ["insightface==0.7.3", "onnx>=1.14.0", "opencv-python>=4.7.0.72", "numpy==1.26.3", "segment_anything", "albumentations>=1.4.16", "ultralytics"] + +[project.urls] +Repository = "https://github.com/Gourieff/ComfyUI-ReActor" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "gourieff" +DisplayName = "ComfyUI-ReActor" +Icon = "" diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..871b6366a986e7a816a5a0dd0ca900b3ca4450c1 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py @@ -0,0 +1,12 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .test import * +from .train import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94c531f5bfb143b4ca735a60ad3f3307a6d14978 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'r_basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fc628abb70d95bfc1c64906006307f1258417ffd --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py @@ -0,0 +1,322 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +try: + from distutils.version import LooseVersion +except: + from packaging.version import Version + LooseVersion = Version +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from 
r_basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from r_basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. 
+ align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..c438d3513807e4b164ffaa9c843bfa0e6733ba44 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py @@ -0,0 +1,336 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, flow_warp, make_layer +from .edvr_arch import PCDAlignment, TSAFusion +from .spynet_arch import SpyNet + + +@ARCH_REGISTRY.register() +class BasicVSR(nn.Module): + """A recurrent network for video SR. Now only x4 is supported. + + Args: + num_feat (int): Number of channels. Default: 64. + num_block (int): Number of residual blocks for each branch. Default: 15 + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + """ + + def __init__(self, num_feat=64, num_block=15, spynet_path=None): + super().__init__() + self.num_feat = num_feat + + # alignment + self.spynet = SpyNet(spynet_path) + + # propagation + self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + + # reconstruction + self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True) + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + self.pixel_shuffle = nn.PixelShuffle(2) + + # activation functions + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def get_flow(self, x): + b, n, c, h, w = x.size() + + x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w) + x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w) + flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w) + + return flows_forward, flows_backward + + def forward(self, x): + """Forward function of BasicVSR. + + Args: + x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames. 
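+
+        Returns:
+            Tensor: Super-resolved frames with shape (b, n, c, 4h, 4w) (x4 upsampling).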
+ """ + flows_forward, flows_backward = self.get_flow(x) + b, n, _, h, w = x.size() + + # backward branch + out_l = [] + feat_prop = x.new_zeros(b, self.num_feat, h, w) + for i in range(n - 1, -1, -1): + x_i = x[:, i, :, :, :] + if i < n - 1: + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.backward_trunk(feat_prop) + out_l.insert(0, feat_prop) + + # forward branch + feat_prop = torch.zeros_like(feat_prop) + for i in range(0, n): + x_i = x[:, i, :, :, :] + if i > 0: + flow = flows_forward[:, i - 1, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.forward_trunk(feat_prop) + + # upsample + out = torch.cat([out_l[i], feat_prop], dim=1) + out = self.lrelu(self.fusion(out)) + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False) + out += base + out_l[i] = out + + return torch.stack(out_l, dim=1) + + +class ConvResidualBlocks(nn.Module): + """Conv and residual block used in BasicVSR. + + Args: + num_in_ch (int): Number of input channels. Default: 3. + num_out_ch (int): Number of output channels. Default: 64. + num_block (int): Number of residual blocks. Default: 15. + """ + + def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15): + super().__init__() + self.main = nn.Sequential( + nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True), + make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch)) + + def forward(self, fea): + return self.main(fea) + + +@ARCH_REGISTRY.register() +class IconVSR(nn.Module): + """IconVSR, proposed also in the BasicVSR paper. + + Args: + num_feat (int): Number of channels. Default: 64. + num_block (int): Number of residual blocks for each branch. Default: 15. + keyframe_stride (int): Keyframe stride. Default: 5. + temporal_padding (int): Temporal padding. Default: 2. + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + edvr_path (str): Path to the pretrained EDVR model. Default: None. + """ + + def __init__(self, + num_feat=64, + num_block=15, + keyframe_stride=5, + temporal_padding=2, + spynet_path=None, + edvr_path=None): + super().__init__() + + self.num_feat = num_feat + self.temporal_padding = temporal_padding + self.keyframe_stride = keyframe_stride + + # keyframe_branch + self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path) + # alignment + self.spynet = SpyNet(spynet_path) + + # propagation + self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True) + self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + + self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True) + self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block) + + # reconstruction + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + self.pixel_shuffle = nn.PixelShuffle(2) + + # activation functions + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def pad_spatial(self, x): + """Apply padding spatially. 
+ + Since the PCD module in EDVR requires that the resolution is a multiple + of 4, we apply padding to the input LR images if their resolution is + not divisible by 4. + + Args: + x (Tensor): Input LR sequence with shape (n, t, c, h, w). + Returns: + Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad). + """ + n, t, c, h, w = x.size() + + pad_h = (4 - h % 4) % 4 + pad_w = (4 - w % 4) % 4 + + # padding + x = x.view(-1, c, h, w) + x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect') + + return x.view(n, t, c, h + pad_h, w + pad_w) + + def get_flow(self, x): + b, n, c, h, w = x.size() + + x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w) + x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w) + flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w) + + return flows_forward, flows_backward + + def get_keyframe_feature(self, x, keyframe_idx): + if self.temporal_padding == 2: + x = [x[:, [4, 3]], x, x[:, [-4, -5]]] + elif self.temporal_padding == 3: + x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]] + x = torch.cat(x, dim=1) + + num_frames = 2 * self.temporal_padding + 1 + feats_keyframe = {} + for i in keyframe_idx: + feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous()) + return feats_keyframe + + def forward(self, x): + b, n, _, h_input, w_input = x.size() + + x = self.pad_spatial(x) + h, w = x.shape[3:] + + keyframe_idx = list(range(0, n, self.keyframe_stride)) + if keyframe_idx[-1] != n - 1: + keyframe_idx.append(n - 1) # last frame is a keyframe + + # compute flow and keyframe features + flows_forward, flows_backward = self.get_flow(x) + feats_keyframe = self.get_keyframe_feature(x, keyframe_idx) + + # backward branch + out_l = [] + feat_prop = x.new_zeros(b, self.num_feat, h, w) + for i in range(n - 1, -1, -1): + x_i = x[:, i, :, :, :] + if i < n - 1: + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + if i in keyframe_idx: + feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1) + feat_prop = self.backward_fusion(feat_prop) + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.backward_trunk(feat_prop) + out_l.insert(0, feat_prop) + + # forward branch + feat_prop = torch.zeros_like(feat_prop) + for i in range(0, n): + x_i = x[:, i, :, :, :] + if i > 0: + flow = flows_forward[:, i - 1, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + if i in keyframe_idx: + feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1) + feat_prop = self.forward_fusion(feat_prop) + + feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1) + feat_prop = self.forward_trunk(feat_prop) + + # upsample + out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False) + out += base + out_l[i] = out + + return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input] + + +class EDVRFeatureExtractor(nn.Module): + """EDVR feature extractor used in IconVSR. + + Args: + num_input_frame (int): Number of input frames. + num_feat (int): Number of feature channels + load_path (str): Path to the pretrained weights of EDVR. Default: None. 
+ """ + + def __init__(self, num_input_frame, num_feat, load_path): + + super(EDVRFeatureExtractor, self).__init__() + + self.center_frame_idx = num_input_frame // 2 + + # extract pyramid features + self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1) + self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat) + self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + # pcd and tsa module + self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8) + self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + def forward(self, x): + b, n, c, h, w = x.size() + + # extract features for each frame + # L1 + feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w))) + feat_l1 = self.feature_extraction(feat_l1) + # L2 + feat_l2 = self.lrelu(self.conv_l2_1(feat_l1)) + feat_l2 = self.lrelu(self.conv_l2_2(feat_l2)) + # L3 + feat_l3 = self.lrelu(self.conv_l3_1(feat_l2)) + feat_l3 = self.lrelu(self.conv_l3_2(feat_l3)) + + feat_l1 = feat_l1.view(b, n, -1, h, w) + feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2) + feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4) + + # PCD alignment + ref_feat_l = [ # reference feature list + feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(), + feat_l3[:, self.center_frame_idx, :, :, :].clone() + ] + aligned_feat = [] + for i in range(n): + nbr_feat_l = [ # neighboring feature list + feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone() + ] + aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l)) + aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w) + + # TSA fusion + return self.fusion(aligned_feat) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9d396dddac81fc4d03a6d97dc46faab83a4bab --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py @@ -0,0 +1,407 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import warnings + +from r_basicsr.archs.arch_util import flow_warp +from r_basicsr.archs.basicvsr_arch import ConvResidualBlocks +from r_basicsr.archs.spynet_arch import SpyNet +from r_basicsr.ops.dcn import ModulatedDeformConvPack +from r_basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class BasicVSRPlusPlus(nn.Module): + """BasicVSR++ network structure. + Support either x4 upsampling or same size output. Since DCN is used in this + model, it can only be used with CUDA enabled. If CUDA is not enabled, + feature alignment will be skipped. Besides, we adopt the official DCN + implementation and the version of torch need to be higher than 1.9. + Paper: + BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation + and Alignment + Args: + mid_channels (int, optional): Channel number of the intermediate + features. Default: 64. + num_blocks (int, optional): The number of residual blocks in each + propagation branch. Default: 7. 
+ max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 6 in paper). Default: 10. + is_low_res_input (bool, optional): Whether the input is low-resolution + or not. If False, the output resolution is equal to the input + resolution. Default: True. + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + cpu_cache_length (int, optional): When the length of sequence is larger + than this value, the intermediate features are sent to CPU. This + saves GPU memory, but slows down the inference speed. You can + increase this number if you have a GPU with large memory. + Default: 100. + """ + + def __init__(self, + mid_channels=64, + num_blocks=7, + max_residue_magnitude=10, + is_low_res_input=True, + spynet_path=None, + cpu_cache_length=100): + + super().__init__() + self.mid_channels = mid_channels + self.is_low_res_input = is_low_res_input + self.cpu_cache_length = cpu_cache_length + + # optical flow + self.spynet = SpyNet(spynet_path) + + # feature extraction module + if is_low_res_input: + self.feat_extract = ConvResidualBlocks(3, mid_channels, 5) + else: + self.feat_extract = nn.Sequential( + nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + ConvResidualBlocks(mid_channels, mid_channels, 5)) + + # propagation branches + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + for i, module in enumerate(modules): + if torch.cuda.is_available(): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * mid_channels, + mid_channels, + 3, + padding=1, + deformable_groups=16, + max_residue_magnitude=max_residue_magnitude) + self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks) + + # upsampling module + self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5) + + self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True) + + self.pixel_shuffle = nn.PixelShuffle(2) + + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # check if the sequence is augmented by flipping + self.is_mirror_extended = False + + if len(self.deform_align) > 0: + self.is_with_alignment = True + else: + self.is_with_alignment = False + warnings.warn('Deformable alignment module is not added. ' + 'Probably your CUDA is not configured correctly. DCN can only ' + 'be used with CUDA enabled. Alignment is skipped now.') + + def check_if_mirror_extended(self, lqs): + """Check whether the input is a mirror-extended sequence. + If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the + (t-1-i)-th frame. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + """ + + if lqs.size(1) % 2 == 0: + lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1) + if torch.norm(lqs_1 - lqs_2.flip(1)) == 0: + self.is_mirror_extended = True + + def compute_flow(self, lqs): + """Compute optical flow using SPyNet for feature alignment. + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. 
+ Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the + flows used for forward-time propagation (current to previous). + 'flows_backward' corresponds to the flows used for + backward-time propagation (current to next). + """ + + n, t, c, h, w = lqs.size() + lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w) + lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = flows_backward.flip(1) + else: + flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w) + + if self.cpu_cache: + flows_backward = flows_backward.cpu() + flows_forward = flows_forward.cpu() + + return flows_forward, flows_backward + + def propagate(self, feats, flows, module_name): + """Propagate the latent features throughout the sequence. + Args: + feats dict(list[tensor]): Features from previous branches. Each + component is a list of tensors with shape (n, c, h, w). + flows (tensor): Optical flows with shape (n, t - 1, 2, h, w). + module_name (str): The name of the propgation branches. Can either + be 'backward_1', 'forward_1', 'backward_2', 'forward_2'. + Return: + dict(list[tensor]): A dictionary containing all the propagated + features. Each key in the dictionary corresponds to a + propagation branch, which is represented by a list of tensors. + """ + + n, t, _, h, w = flows.size() + + frame_idx = range(0, t + 1) + flow_idx = range(-1, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + flow_idx = frame_idx + + feat_prop = flows.new_zeros(n, self.mid_channels, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + if self.cpu_cache: + feat_current = feat_current.cuda() + feat_prop = feat_prop.cuda() + # second-order deformable alignment + if i > 0 and self.is_with_alignment: + flow_n1 = flows[:, flow_idx[i], :, :, :] + if self.cpu_cache: + flow_n1 = flow_n1.cuda() + + cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1)) + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + flow_n2 = torch.zeros_like(flow_n1) + cond_n2 = torch.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats[module_name][-2] + if self.cpu_cache: + feat_n2 = feat_n2.cuda() + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + if self.cpu_cache: + flow_n2 = flow_n2.cuda() + + flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1)) + cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1)) + + # flow-guided deformable convolution + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop] + if self.cpu_cache: + feat = [f.cuda() for f in feat] + + feat = torch.cat(feat, dim=1) + feat_prop = feat_prop + self.backbone[module_name](feat) + feats[module_name].append(feat_prop) + + if self.cpu_cache: + feats[module_name][-1] = feats[module_name][-1].cpu() + torch.cuda.empty_cache() + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + return feats + + def upsample(self, lqs, 
feats): + """Compute the output image given the features. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + feats (dict): The features from the propgation branches. + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + outputs = [] + num_outputs = len(feats['spatial']) + + mapping_idx = list(range(0, num_outputs)) + mapping_idx += mapping_idx[::-1] + + for i in range(0, lqs.size(1)): + hr = [feats[k].pop(0) for k in feats if k != 'spatial'] + hr.insert(0, feats['spatial'][mapping_idx[i]]) + hr = torch.cat(hr, dim=1) + if self.cpu_cache: + hr = hr.cuda() + + hr = self.reconstruction(hr) + hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr))) + hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr))) + hr = self.lrelu(self.conv_hr(hr)) + hr = self.conv_last(hr) + if self.is_low_res_input: + hr += self.img_upsample(lqs[:, i, :, :, :]) + else: + hr += lqs[:, i, :, :, :] + + if self.cpu_cache: + hr = hr.cpu() + torch.cuda.empty_cache() + + outputs.append(hr) + + return torch.stack(outputs, dim=1) + + def forward(self, lqs): + """Forward function for BasicVSR++. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, c, h, w = lqs.size() + + # whether to cache the features in CPU + self.cpu_cache = True if t > self.cpu_cache_length else False + + if self.is_low_res_input: + lqs_downsample = lqs.clone() + else: + lqs_downsample = F.interpolate( + lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4) + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lqs) + + feats = {} + # compute spatial features + if self.cpu_cache: + feats['spatial'] = [] + for i in range(0, t): + feat = self.feat_extract(lqs[:, i, :, :, :]).cpu() + feats['spatial'].append(feat) + torch.cuda.empty_cache() + else: + feats_ = self.feat_extract(lqs.view(-1, c, h, w)) + h, w = feats_.shape[2:] + feats_ = feats_.view(n, t, -1, h, w) + feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)] + + # compute optical flow using the low-res inputs + assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, ( + 'The height and width of low-res inputs must be at least 64, ' + f'but got {h} and {w}.') + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + # feature propgation + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + module = f'{direction}_{iter_}' + + feats[module] = [] + + if direction == 'backward': + flows = flows_backward + elif flows_forward is not None: + flows = flows_forward + else: + flows = flows_backward.flip(1) + + feats = self.propagate(feats, flows, module) + if self.cpu_cache: + del flows + torch.cuda.empty_cache() + + return self.upsample(lqs, feats) + + +class SecondOrderDeformableAlignment(ModulatedDeformConvPack): + """Second-order deformable alignment module. + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 
6 in paper). Default: 10. + """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1), + ) + + self.init_offset() + + def init_offset(self): + + def _constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + _constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat, flow_1, flow_2): + extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1) + offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + + +# if __name__ == '__main__': +# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth' +# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda() +# input = torch.rand(1, 2, 3, 64, 64).cuda() +# output = model(input) +# print('===================') +# print(output.shape) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..646da689c5e8f88ce879ce5f735e6273d6724319 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.spectral_norm import spectral_norm + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization +from .vgg_arch import VGGFeatureExtractor + + +class SFTUpBlock(nn.Module): + """Spatial feature transform (SFT) with upsampling block. + + Args: + in_channel (int): Number of input channels. + out_channel (int): Number of output channels. + kernel_size (int): Kernel size in convolutions. Default: 3. + padding (int): Padding in convolutions. Default: 1. 
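Illustrative sketch (not part of the original patch) of the flow-guided offset computation in SecondOrderDeformableAlignment above: a tanh bounds the learned residue to ±max_residue_magnitude, and the optical flow (x/y swapped via flip(1)) is added as the base offset. Shapes assume deformable_groups = 1 and show only one of the two flow branches.

import torch

max_residue_magnitude = 10
raw = torch.randn(1, 18, 4, 4) * 5                    # predicted residue: 9 sampling points x 2 coords
flow = torch.randn(1, 2, 4, 4)                        # optical flow for this branch
residue = max_residue_magnitude * torch.tanh(raw)     # bounded to [-10, 10]
offset = residue + flow.flip(1).repeat(1, raw.size(1) // 2, 1, 1)
print(offset.shape)                                   # torch.Size([1, 18, 4, 4])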
+ """ + + def __init__(self, in_channel, out_channel, kernel_size=3, padding=1): + super(SFTUpBlock, self).__init__() + self.conv1 = nn.Sequential( + Blur(in_channel), + spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.04, True), + # The official codes use two LeakyReLU here, so 0.04 for equivalent + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2, True), + ) + + # for SFT scale and shift + self.scale_block = nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))) + self.shift_block = nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid()) + # The official codes use sigmoid for shift block, do not know why + + def forward(self, x, updated_feat): + out = self.conv1(x) + # SFT + scale = self.scale_block(updated_feat) + shift = self.shift_block(updated_feat) + out = out * scale + shift + # upsample + out = self.convup(out) + return out + + +@ARCH_REGISTRY.register() +class DFDNet(nn.Module): + """DFDNet: Deep Face Dictionary Network. + + It only processes faces with 512x512 size. + + Args: + num_feat (int): Number of feature channels. + dict_path (str): Path to the facial component dictionary. + """ + + def __init__(self, num_feat, dict_path): + super().__init__() + self.parts = ['left_eye', 'right_eye', 'nose', 'mouth'] + # part_sizes: [80, 80, 50, 110] + channel_sizes = [128, 256, 512, 512] + self.feature_sizes = np.array([256, 128, 64, 32]) + self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4'] + self.flag_dict_device = False + + # dict + self.dict = torch.load(dict_path) + + # vgg face extractor + self.vgg_extractor = VGGFeatureExtractor( + layer_name_list=self.vgg_layers, + vgg_type='vgg19', + use_input_norm=True, + range_norm=True, + requires_grad=False) + + # attention block for fusing dictionary features and input features + self.attn_blocks = nn.ModuleDict() + for idx, feat_size in enumerate(self.feature_sizes): + for name in self.parts: + self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx]) + + # multi scale dilation block + self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1]) + + # upsampling and reconstruction + self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8) + self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4) + self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2) + self.upsample3 = SFTUpBlock(num_feat * 2, num_feat) + self.upsample4 = nn.Sequential( + spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat), + UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh()) + + def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size): + """swap the features from the dictionary.""" + # get the original vgg features + part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone() + # resize original vgg features + part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False) + # use adaptive instance normalization to adjust color and illuminations + dict_feat = adaptive_instance_normalization(dict_feat, 
part_resize_feat) + # get similarity scores + similarity_score = F.conv2d(part_resize_feat, dict_feat) + similarity_score = F.softmax(similarity_score.view(-1), dim=0) + # select the most similar features in the dict (after norm) + select_idx = torch.argmax(similarity_score) + swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4]) + # attention + attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat) + attn_feat = attn * swap_feat + # update features + updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat + return updated_feat + + def put_dict_to_device(self, x): + if self.flag_dict_device is False: + for k, v in self.dict.items(): + for kk, vv in v.items(): + self.dict[k][kk] = vv.to(x) + self.flag_dict_device = True + + def forward(self, x, part_locations): + """ + Now only supports testing with batch size = 1. + + Args: + x (Tensor): Input faces with shape (b, c, 512, 512). + part_locations (list[Tensor]): Part locations. + """ + self.put_dict_to_device(x) + # extract vggface features + vgg_features = self.vgg_extractor(x) + # update vggface features using the dictionary for each part + updated_vgg_features = [] + batch = 0 # only supports testing with batch size = 1 + for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes): + dict_features = self.dict[f'{f_size}'] + vgg_feat = vgg_features[vgg_layer] + updated_feat = vgg_feat.clone() + + # swap features from dictionary + for part_idx, part_name in enumerate(self.parts): + location = (part_locations[part_idx][batch] // (512 / f_size)).int() + updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name, + f_size) + + updated_vgg_features.append(updated_feat) + + vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4']) + # use updated vgg features to modulate the upsampled features with + # SFT (Spatial Feature Transform) in a scaling and shifting manner. 
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3]) + upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2]) + upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1]) + upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0]) + out = self.upsample4(upsampled_feat) + + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b4dc0ff738c76852e830b32fffbe65bffb5ddf50 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.utils.spectral_norm import spectral_norm + + +class BlurFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) + return grad_input + + @staticmethod + def backward(ctx, gradgrad_output): + kernel, _ = ctx.saved_tensors + grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) + return grad_input, None, None + + +class BlurFunction(Function): + + @staticmethod + def forward(ctx, x, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) + return output + + @staticmethod + def backward(ctx, grad_output): + kernel, kernel_flip = ctx.saved_tensors + grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) + return grad_input, None, None + + +blur = BlurFunction.apply + + +class Blur(nn.Module): + + def __init__(self, channel): + super().__init__() + kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) + kernel = kernel.view(1, 1, 3, 3) + kernel = kernel / kernel.sum() + kernel_flip = torch.flip(kernel, [2, 3]) + + self.kernel = kernel.repeat(channel, 1, 1, 1) + self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) + + def forward(self, x): + return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + n, c = size[:2] + feat_var = feat.view(n, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(n, c, 1, 1) + feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. 
+ """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def AttentionBlock(in_channel): + return nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) + + +def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): + """Conv block used in MSDilationBlock.""" + + return nn.Sequential( + spectral_norm( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias)), + nn.LeakyReLU(0.2), + spectral_norm( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias)), + ) + + +class MSDilationBlock(nn.Module): + """Multi-scale dilation block.""" + + def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): + super(MSDilationBlock, self).__init__() + + self.conv_blocks = nn.ModuleList() + for i in range(4): + self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) + self.conv_fusion = spectral_norm( + nn.Conv2d( + in_channels * 4, + in_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias)) + + def forward(self, x): + out = [] + for i in range(4): + out.append(self.conv_blocks[i](x)) + out = torch.cat(out, 1) + out = self.conv_fusion(out) + x + return out + + +class UpResBlock(nn.Module): + + def __init__(self, in_channel): + super(UpResBlock, self).__init__() + self.body = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + ) + + def forward(self, x): + out = x + self.body(x) + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5ae93aaffc749bc00f27947a6b67d2ae327109 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py @@ -0,0 +1,150 @@ +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + +from r_basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class VGGStyleDiscriminator(nn.Module): + """VGG style discriminator with input size 128 x 128 or 256 x 256. + + It is used to train SRGAN, ESRGAN, and VideoGAN. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features.Default: 64. 
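Hedged usage sketch for the dfdnet_util helpers above (Blur and adaptive_instance_normalization), commented out in the style of the example at the end of basicvsr_arch above; the import path is taken from this patch and may differ in other setups.

# import torch
# from r_basicsr.archs.dfdnet_util import Blur, adaptive_instance_normalization
#
# x = torch.rand(1, 8, 16, 16)
# blurred = Blur(8)(x)                                    # 3x3 binomial blur, shape preserved
# styled = adaptive_instance_normalization(x, blurred)    # match channel-wise mean/std of `blurred`
# print(blurred.shape, styled.shape)                      # both torch.Size([1, 8, 16, 16])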
+ """ + + def __init__(self, num_in_ch, num_feat, input_size=128): + super(VGGStyleDiscriminator, self).__init__() + self.input_size = input_size + assert self.input_size == 128 or self.input_size == 256, ( + f'input size must be 128 or 256, but received {input_size}') + + self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True) + + self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True) + self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True) + + self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True) + self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True) + + self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + if self.input_size == 256: + self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) + self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.') + + feat = self.lrelu(self.conv0_0(x)) + feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2 + + feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) + feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4 + + feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) + feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8 + + feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) + feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16 + + feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) + feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32 + + if self.input_size == 256: + feat = self.lrelu(self.bn5_0(self.conv5_0(feat))) + feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64 + + # spatial size: (4, 4) + feat = feat.view(feat.size(0), -1) + feat = self.lrelu(self.linear1(feat)) + out = self.linear2(feat) + return out + + +@ARCH_REGISTRY.register(suffix='basicsr') +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. 
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True. + """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9dec88bf4f183e807c22a4e103845defc7228c --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py @@ -0,0 +1,277 @@ +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY + + +class DenseBlocksTemporalReduce(nn.Module): + """A concatenation of 3 dense blocks with reduction in temporal dimension. + + Note that the output temporal dimension is 6 fewer the input temporal dimension, since there are 3 blocks. + + Args: + num_feat (int): Number of channels in the blocks. Default: 64. + num_grow_ch (int): Growing factor of the dense blocks. Default: 32 + adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation. + Set to false if you want to train from scratch. Default: False. 
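Hedged usage sketch for the U-Net discriminator defined earlier in discriminator_arch.py above (commented out; the import path is assumed from this patch). It returns a per-pixel realness map at the input resolution.

# import torch
# from r_basicsr.archs.discriminator_arch import UNetDiscriminatorSN
#
# d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
# logits = d(torch.rand(1, 3, 128, 128))   # per-pixel realness scores
# print(logits.shape)                      # torch.Size([1, 1, 128, 128])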
+ """ + + def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False): + super(DenseBlocksTemporalReduce, self).__init__() + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.temporal_reduce1 = nn.Sequential( + nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True), + nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + self.temporal_reduce2 = nn.Sequential( + nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + num_grow_ch, + num_feat + num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + self.temporal_reduce3 = nn.Sequential( + nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + 2 * num_grow_ch, + num_feat + 2 * num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), + nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + def forward(self, x): + """ + Args: + x (Tensor): Input tensor with shape (b, num_feat, t, h, w). + + Returns: + Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w). + """ + x1 = self.temporal_reduce1(x) + x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1) + + x2 = self.temporal_reduce2(x1) + x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1) + + x3 = self.temporal_reduce3(x2) + x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1) + + return x3 + + +class DenseBlocks(nn.Module): + """ A concatenation of N dense blocks. + + Args: + num_feat (int): Number of channels in the blocks. Default: 64. + num_grow_ch (int): Growing factor of the dense blocks. Default: 32. + num_block (int): Number of dense blocks. The values are: + DUF-S (16 layers): 3 + DUF-M (18 layers): 9 + DUF-L (52 layers): 21 + adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation. + Set to false if you want to train from scratch. Default: False. + """ + + def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False): + super(DenseBlocks, self).__init__() + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.dense_blocks = nn.ModuleList() + for i in range(0, num_block): + self.dense_blocks.append( + nn.Sequential( + nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + i * num_grow_ch, + num_feat + i * num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), + nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + i * num_grow_ch, + num_grow_ch, (3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True))) + + def forward(self, x): + """ + Args: + x (Tensor): Input tensor with shape (b, num_feat, t, h, w). 
+ + Returns: + Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w). + """ + for i in range(0, len(self.dense_blocks)): + y = self.dense_blocks[i](x) + x = torch.cat((x, y), 1) + return x + + +class DynamicUpsamplingFilter(nn.Module): + """Dynamic upsampling filter used in DUF. + + Ref: https://github.com/yhjo09/VSR-DUF. + It only supports input with 3 channels. And it applies the same filters to 3 channels. + + Args: + filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5). + """ + + def __init__(self, filter_size=(5, 5)): + super(DynamicUpsamplingFilter, self).__init__() + if not isinstance(filter_size, tuple): + raise TypeError(f'The type of filter_size must be tuple, but got type{filter_size}') + if len(filter_size) != 2: + raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.') + # generate a local expansion filter, similar to im2col + self.filter_size = filter_size + filter_prod = np.prod(filter_size) + expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw) + self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels + + def forward(self, x, filters): + """Forward function for DynamicUpsamplingFilter. + + Args: + x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w). + filters (Tensor): Generated dynamic filters. + The shape is (n, filter_prod, upsampling_square, h, w). + filter_prod: prod of filter kernel size, e.g., 1*5*5=25. + upsampling_square: similar to pixel shuffle, + upsampling_square = upsampling * upsampling + e.g., for x 4 upsampling, upsampling_square= 4*4 = 16 + + Returns: + Tensor: Filtered image with shape (n, 3*upsampling_square, h, w) + """ + n, filter_prod, upsampling_square, h, w = filters.size() + kh, kw = self.filter_size + expanded_input = F.conv2d( + x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w) + expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1, + 2) # (n, h, w, 3, filter_prod) + filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square] + out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square) + return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w) + + +@ARCH_REGISTRY.register() +class DUF(nn.Module): + """Network architecture for DUF + + Paper: Jo et.al. Deep Video Super-Resolution Network Using Dynamic + Upsampling Filters Without Explicit Motion Compensation, CVPR, 2018 + Code reference: + https://github.com/yhjo09/VSR-DUF + For all the models below, 'adapt_official_weights' is only necessary when + loading the weights converted from the official TensorFlow weights. + Please set it to False if you are training the model from scratch. + + There are three models with different model size: DUF16Layers, DUF28Layers, + and DUF52Layers. This class is the base class for these models. + + Args: + scale (int): The upsampling factor. Default: 4. + num_layer (int): The number of layers. Default: 52. + adapt_official_weights_weights (bool): Whether to adapt the weights + translated from the official implementation. Set to false if you + want to train from scratch. Default: False. 
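Hedged shape sketch for the DynamicUpsamplingFilter defined above, for x4 upsampling (commented out; random tensors stand in for the network-predicted filters, and the import path is assumed from this patch).

# import torch
# import torch.nn.functional as F
# from r_basicsr.archs.duf_arch import DynamicUpsamplingFilter
#
# duf_filter = DynamicUpsamplingFilter((5, 5))
# x = torch.rand(2, 3, 16, 16)                                    # (n, 3, h, w)
# filters = torch.softmax(torch.rand(2, 25, 16, 16, 16), dim=1)   # (n, kh*kw, upsampling^2, h, w)
# out = duf_filter(x, filters)                                    # (n, 3 * upsampling^2, h, w)
# y = F.pixel_shuffle(out, 4)                                     # (2, 3, 64, 64): the x4 upsampled frame
# print(out.shape, y.shape)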
+ """ + + def __init__(self, scale=4, num_layer=52, adapt_official_weights=False): + super(DUF, self).__init__() + self.scale = scale + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + self.dynamic_filter = DynamicUpsamplingFilter((5, 5)) + + if num_layer == 16: + num_block = 3 + num_grow_ch = 32 + elif num_layer == 28: + num_block = 9 + num_grow_ch = 16 + elif num_layer == 52: + num_block = 21 + num_grow_ch = 16 + else: + raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.') + + self.dense_block1 = DenseBlocks( + num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch, + adapt_official_weights=adapt_official_weights) # T = 7 + self.dense_block2 = DenseBlocksTemporalReduce( + 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1 + channels = 64 + num_grow_ch * num_block + num_grow_ch * 3 + self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum) + self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + + self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + + self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + self.conv3d_f2 = nn.Conv3d( + 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + + def forward(self, x): + """ + Args: + x (Tensor): Input with shape (b, 7, c, h, w) + + Returns: + Tensor: Output with shape (b, c, h * scale, w * scale) + """ + num_batches, num_imgs, _, h, w = x.size() + + x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D + x_center = x[:, :, num_imgs // 2, :, :] + + x = self.conv3d1(x) + x = self.dense_block1(x) + x = self.dense_block2(x) + x = F.relu(self.bn3d2(x), inplace=True) + x = F.relu(self.conv3d2(x), inplace=True) + + # residual image + res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) + + # filter + filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) + filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1) + + # dynamic filter + out = self.dynamic_filter(x_center, filter_) + out += res.squeeze_(2) + out = F.pixel_shuffle(out, self.scale) + + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..199d786fddfdd160ef524fb161bcbcb4668a38e6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY + + +class SeqConv3x3(nn.Module): + """The re-parameterizable block used in the ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian. + in_channels (int): Channel number of input. + out_channels (int): Channel number of output. + depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1. 
+ """ + + def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1): + super(SeqConv3x3, self).__init__() + self.seq_type = seq_type + self.in_channels = in_channels + self.out_channels = out_channels + + if self.seq_type == 'conv1x1-conv3x3': + self.mid_planes = int(out_channels * depth_multiplier) + conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3) + self.k1 = conv1.weight + self.b1 = conv1.bias + + elif self.seq_type == 'conv1x1-sobelx': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(scale) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(bias) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 1, 0] = 2.0 + self.mask[i, 0, 2, 0] = 1.0 + self.mask[i, 0, 0, 2] = -1.0 + self.mask[i, 0, 1, 2] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.seq_type == 'conv1x1-sobely': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 0, 1] = 2.0 + self.mask[i, 0, 0, 2] = 1.0 + self.mask[i, 0, 2, 0] = -1.0 + self.mask[i, 0, 2, 1] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.seq_type == 'conv1x1-laplacian': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 1] = 1.0 + self.mask[i, 0, 1, 0] = 1.0 + self.mask[i, 0, 1, 2] = 1.0 + self.mask[i, 0, 2, 1] = 1.0 + self.mask[i, 0, 1, 1] = -4.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + else: + raise ValueError('The type of seqconv is not supported!') + + def forward(self, x): + if self.seq_type == 'conv1x1-conv3x3': + # conv-1x1 + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, 
weight=self.k1, bias=self.b1, stride=1) + else: + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels) + return y1 + + def rep_params(self): + device = self.k0.get_device() + if device < 0: + device = None + + if self.seq_type == 'conv1x1-conv3x3': + # re-param conv kernel + rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1 + else: + tmp = self.scale * self.mask + k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device) + for i in range(self.out_channels): + k1[i, i, :, :] = tmp[i, 0, :, :] + b1 = self.bias + # re-param conv kernel + rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1 + return rep_weight, rep_bias + + +class ECB(nn.Module): + """The ECB block used in the ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + in_channels (int): Channel number of input. + out_channels (int): Channel number of output. + depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1. + act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu. + with_idt (bool): Whether to use identity connection. Default: False. 
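Hedged equivalence check for the re-parameterization implemented by SeqConv3x3.rep_params above (commented out; import path assumed from this patch): at inference, the expand-and-squeeze branch collapses into a single 3x3 convolution.

# import torch
# import torch.nn.functional as F
# from r_basicsr.archs.ecbsr_arch import SeqConv3x3
#
# blk = SeqConv3x3('conv1x1-conv3x3', in_channels=4, out_channels=8, depth_multiplier=2)
# x = torch.rand(1, 4, 12, 12)
# w, b = blk.rep_params()
# y_branch = blk(x)                                    # 1x1 conv + bias padding + 3x3 conv
# y_fused = F.conv2d(x, w, b, stride=1, padding=1)     # single fused 3x3 conv
# print(torch.allclose(y_branch, y_fused, atol=1e-5))  # True up to numerical tolerance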
+ """ + + def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False): + super(ECB, self).__init__() + + self.depth_multiplier = depth_multiplier + self.in_channels = in_channels + self.out_channels = out_channels + self.act_type = act_type + + if with_idt and (self.in_channels == self.out_channels): + self.with_idt = True + else: + self.with_idt = False + + self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1) + self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier) + self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels) + self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels) + self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels) + + if self.act_type == 'prelu': + self.act = nn.PReLU(num_parameters=self.out_channels) + elif self.act_type == 'relu': + self.act = nn.ReLU(inplace=True) + elif self.act_type == 'rrelu': + self.act = nn.RReLU(lower=-0.05, upper=0.05) + elif self.act_type == 'softplus': + self.act = nn.Softplus() + elif self.act_type == 'linear': + pass + else: + raise ValueError('The type of activation if not support!') + + def forward(self, x): + if self.training: + y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x) + if self.with_idt: + y += x + else: + rep_weight, rep_bias = self.rep_params() + y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1) + if self.act_type != 'linear': + y = self.act(y) + return y + + def rep_params(self): + weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias + weight1, bias1 = self.conv1x1_3x3.rep_params() + weight2, bias2 = self.conv1x1_sbx.rep_params() + weight3, bias3 = self.conv1x1_sby.rep_params() + weight4, bias4 = self.conv1x1_lpl.rep_params() + rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), ( + bias0 + bias1 + bias2 + bias3 + bias4) + + if self.with_idt: + device = rep_weight.get_device() + if device < 0: + device = None + weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device) + for i in range(self.out_channels): + weight_idt[i, i, 1, 1] = 1.0 + bias_idt = 0.0 + rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt + return rep_weight, rep_bias + + +@ARCH_REGISTRY.register() +class ECBSR(nn.Module): + """ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_block (int): Block number in the trunk network. + num_channel (int): Channel number. + with_idt (bool): Whether use identity in convolution layers. + act_type (str): Activation type. + scale (int): Upsampling factor. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale): + super(ECBSR, self).__init__() + self.num_in_ch = num_in_ch + self.scale = scale + + backbone = [] + backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] + for _ in range(num_block): + backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] + backbone += [ + ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt) + ] + + self.backbone = nn.Sequential(*backbone) + self.upsampler = nn.PixelShuffle(scale) + + def forward(self, x): + if self.num_in_ch > 1: + shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1) + else: + shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times) + y = self.backbone(x) + shortcut + y = self.upsampler(y) + return y diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..78ef9c4eb7b0d1c78ac6f56403b57f3862a32902 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py @@ -0,0 +1,61 @@ +import torch +from torch import nn as nn + +from r_basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer +from r_basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class EDSR(nn.Module): + """EDSR network structure. + + Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. + Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64. + num_block (int): Block number in the trunk network. Default: 16. + upscale (int): Upsampling factor. Support 2^n and 3. + Default: 4. + res_scale (float): Used to scale the residual in residual block. + Default: 1. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
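Illustrative sketch (not part of the original patch) of the mean/range normalization described in the EDSR docstring above: inputs are zero-centered with the DIV2K RGB mean, scaled by img_range, and the inverse mapping is applied to the network output.

import torch

rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
img_range = 255.
x = torch.rand(1, 3, 8, 8)                  # input in [0, 1]
x_norm = (x - rgb_mean) * img_range         # what conv_first sees
x_restored = x_norm / img_range + rgb_mean  # inverse applied after conv_last
print(torch.allclose(x, x_restored))        # True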
+ """ + + def __init__(self, + num_in_ch, + num_out_ch, + num_feat=64, + num_block=16, + upscale=4, + res_scale=1, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040)): + super(EDSR, self).__init__() + + self.img_range = img_range + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) + self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + self.mean = self.mean.type_as(x) + + x = (x - self.mean) * self.img_range + x = self.conv_first(x) + res = self.conv_after_body(self.body(x)) + res += x + + x = self.conv_last(self.upsample(res)) + x = x / self.img_range + self.mean + + return x diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5e0388d01f1552592d0dc50854d38f5aad55ad --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py @@ -0,0 +1,383 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer + + +class PCDAlignment(nn.Module): + """Alignment module using Pyramid, Cascading and Deformable convolution + (PCD). It is used in EDVR. + + Ref: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + num_feat (int): Channel number of middle features. Default: 64. + deformable_groups (int): Deformable groups. Defaults: 8. + """ + + def __init__(self, num_feat=64, deformable_groups=8): + super(PCDAlignment, self).__init__() + + # Pyramid has three levels: + # L3: level 3, 1/4 spatial size + # L2: level 2, 1/2 spatial size + # L1: level 1, original spatial size + self.offset_conv1 = nn.ModuleDict() + self.offset_conv2 = nn.ModuleDict() + self.offset_conv3 = nn.ModuleDict() + self.dcn_pack = nn.ModuleDict() + self.feat_conv = nn.ModuleDict() + + # Pyramids + for i in range(3, 0, -1): + level = f'l{i}' + self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + if i == 3: + self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + else: + self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + if i < 3: + self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + + # Cascading dcn + self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, nbr_feat_l, ref_feat_l): + """Align neighboring frame features to the reference frame features. + + Args: + nbr_feat_l (list[Tensor]): Neighboring feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + ref_feat_l (list[Tensor]): Reference feature list. 
It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + + Returns: + Tensor: Aligned features. + """ + # Pyramids + upsampled_offset, upsampled_feat = None, None + for i in range(3, 0, -1): + level = f'l{i}' + offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 3: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1))) + offset = self.lrelu(self.offset_conv3[level](offset)) + + feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset) + if i < 3: + feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1)) + if i > 1: + feat = self.lrelu(feat) + + if i > 1: # upsample offset and features + # x2: when we upsample the offset, we should also enlarge + # the magnitude. + upsampled_offset = self.upsample(offset) * 2 + upsampled_feat = self.upsample(feat) + + # Cascading + offset = torch.cat([feat, ref_feat_l[0]], dim=1) + offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset)))) + feat = self.lrelu(self.cas_dcnpack(feat, offset)) + return feat + + +class TSAFusion(nn.Module): + """Temporal Spatial Attention (TSA) fusion module. + + Temporal: Calculate the correlation between center frame and + neighboring frames; + Spatial: It has 3 pyramid levels, the attention is similar to SFT. + (SFT: Recovering realistic texture in image super-resolution by deep + spatial feature transform.) + + Args: + num_feat (int): Channel number of middle features. Default: 64. + num_frame (int): Number of frames. Default: 5. + center_frame_idx (int): The index of center frame. Default: 2. + """ + + def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2): + super(TSAFusion, self).__init__() + self.center_frame_idx = center_frame_idx + # temporal attention (before fusion conv) + self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # spatial attention (after fusion conv) + self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) + self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1) + self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1) + self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1) + self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, aligned_feat): + """ + Args: + aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w). + + Returns: + Tensor: Features after TSA with the shape (b, c, h, w). 
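Illustrative sketch (not part of the original patch) of the temporal attention described above: a per-pixel correlation between the embedded center frame and each embedded neighboring frame, squashed into (0, 1) weights.

import torch

b, t, c, h, w = 1, 5, 4, 6, 6
embedding = torch.rand(b, t, c, h, w)                    # embedded aligned features
embedding_ref = embedding[:, t // 2]                     # center frame (index 2)
corr = (embedding * embedding_ref.unsqueeze(1)).sum(2)   # (b, t, h, w) correlation maps
corr_prob = torch.sigmoid(corr)                          # per-frame, per-pixel weights
print(corr_prob.shape)                                   # torch.Size([1, 5, 6, 6])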
+ """ + b, t, c, h, w = aligned_feat.size() + # temporal attention + embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone()) + embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w)) + embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w) + + corr_l = [] # correlation list + for i in range(t): + emb_neighbor = embedding[:, i, :, :, :] + corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w) + corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w) + corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w) + corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w) + corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w) + aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob + + # fusion + feat = self.lrelu(self.feat_fusion(aligned_feat)) + + # spatial attention + attn = self.lrelu(self.spatial_attn1(aligned_feat)) + attn_max = self.max_pool(attn) + attn_avg = self.avg_pool(attn) + attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1))) + # pyramid levels + attn_level = self.lrelu(self.spatial_attn_l1(attn)) + attn_max = self.max_pool(attn_level) + attn_avg = self.avg_pool(attn_level) + attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1))) + attn_level = self.lrelu(self.spatial_attn_l3(attn_level)) + attn_level = self.upsample(attn_level) + + attn = self.lrelu(self.spatial_attn3(attn)) + attn_level + attn = self.lrelu(self.spatial_attn4(attn)) + attn = self.upsample(attn) + attn = self.spatial_attn5(attn) + attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn))) + attn = torch.sigmoid(attn) + + # after initialization, * 2 makes (attn * 2) to be close to 1. + feat = feat * attn * 2 + attn_add + return feat + + +class PredeblurModule(nn.Module): + """Pre-dublur module. + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + hr_in (bool): Whether the input has high resolution. Default: False. 
+ """ + + def __init__(self, num_in_ch=3, num_feat=64, hr_in=False): + super(PredeblurModule, self).__init__() + self.hr_in = hr_in + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + if self.hr_in: + # downsample x4 by stride conv + self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + # generate feature pyramid + self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)]) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + feat_l1 = self.lrelu(self.conv_first(x)) + if self.hr_in: + feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1)) + feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1)) + + # generate feature pyramid + feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1)) + feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2)) + + feat_l3 = self.upsample(self.resblock_l3(feat_l3)) + feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3 + feat_l2 = self.upsample(self.resblock_l2_2(feat_l2)) + + for i in range(2): + feat_l1 = self.resblock_l1[i](feat_l1) + feat_l1 = feat_l1 + feat_l2 + for i in range(2, 5): + feat_l1 = self.resblock_l1[i](feat_l1) + return feat_l1 + + +@ARCH_REGISTRY.register() +class EDVR(nn.Module): + """EDVR network structure for video super-resolution. + + Now only support X4 upsampling factor. + Paper: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_out_ch (int): Channel number of output image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_frame (int): Number of input frames. Default: 5. + deformable_groups (int): Deformable groups. Defaults: 8. + num_extract_block (int): Number of blocks for feature extraction. + Default: 5. + num_reconstruct_block (int): Number of blocks for reconstruction. + Default: 10. + center_frame_idx (int): The index of center frame. Frame counting from + 0. Default: Middle of input frames. + hr_in (bool): Whether the input has high resolution. Default: False. + with_predeblur (bool): Whether has predeblur module. + Default: False. + with_tsa (bool): Whether has TSA module. Default: True. 
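Hedged usage sketch for the EDVR model documented above (commented out; the import path is assumed from this patch, and the deformable-convolution op behind DCNv2Pack must be available in the environment).

# import torch
# from r_basicsr.archs.edvr_arch import EDVR
#
# model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True)
# with torch.no_grad():
#     out = model(torch.rand(1, 5, 3, 64, 64))   # 5 frames; H and W must be multiples of 4
# print(out.shape)                               # torch.Size([1, 3, 256, 256]) for x4 upsampling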
+ """ + + def __init__(self, + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_frame=5, + deformable_groups=8, + num_extract_block=5, + num_reconstruct_block=10, + center_frame_idx=None, + hr_in=False, + with_predeblur=False, + with_tsa=True): + super(EDVR, self).__init__() + if center_frame_idx is None: + self.center_frame_idx = num_frame // 2 + else: + self.center_frame_idx = center_frame_idx + self.hr_in = hr_in + self.with_predeblur = with_predeblur + self.with_tsa = with_tsa + + # extract features for each frame + if self.with_predeblur: + self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in) + self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1) + else: + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + + # extract pyramid features + self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat) + self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + # pcd and tsa module + self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups) + if self.with_tsa: + self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx) + else: + self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # reconstruction + self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat) + # upsample + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(2) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + b, t, c, h, w = x.size() + if self.hr_in: + assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.') + else: + assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.') + + x_center = x[:, self.center_frame_idx, :, :, :].contiguous() + + # extract features for each frame + # L1 + if self.with_predeblur: + feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w))) + if self.hr_in: + h, w = h // 4, w // 4 + else: + feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w))) + + feat_l1 = self.feature_extraction(feat_l1) + # L2 + feat_l2 = self.lrelu(self.conv_l2_1(feat_l1)) + feat_l2 = self.lrelu(self.conv_l2_2(feat_l2)) + # L3 + feat_l3 = self.lrelu(self.conv_l3_1(feat_l2)) + feat_l3 = self.lrelu(self.conv_l3_2(feat_l3)) + + feat_l1 = feat_l1.view(b, t, -1, h, w) + feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2) + feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4) + + # PCD alignment + ref_feat_l = [ # reference feature list + feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(), + feat_l3[:, self.center_frame_idx, :, :, :].clone() + ] + aligned_feat = [] + for i in range(t): + nbr_feat_l = [ # neighboring feature list + feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone() + ] + aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l)) + aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w) + + if not self.with_tsa: + aligned_feat = aligned_feat.view(b, -1, h, w) + feat = self.fusion(aligned_feat) + + out = self.reconstruction(feat) + out = 
self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + if self.hr_in: + base = x_center + else: + base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) + out += base + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..507ed3c999e4c4954079f1eae7d402587827f621 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py @@ -0,0 +1,259 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer + + +class SPADEGenerator(BaseNetwork): + """Generator with SPADEResBlock""" + + def __init__(self, + num_in_ch=3, + num_feat=64, + use_vae=False, + z_dim=256, + crop_size=512, + norm_g='spectralspadesyncbatch3x3', + is_train=True, + init_train_phase=3): # progressive training disabled + super().__init__() + self.nf = num_feat + self.input_nc = num_in_ch + self.is_train = is_train + self.train_phase = init_train_phase + + self.scale_ratio = 5 # hardcoded now + self.sw = crop_size // (2**self.scale_ratio) + self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0 + + if use_vae: + # In case of VAE, we will sample from random z vector + self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh) + else: + # Otherwise, we make the network deterministic by starting with + # downsampled segmentation map instead of random z + self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1) + + self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + + self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + + self.ups = nn.ModuleList([ + SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g), + SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g), + SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g), + SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g) + ]) + + self.to_rgbs = nn.ModuleList([ + nn.Conv2d(8 * self.nf, 3, 3, padding=1), + nn.Conv2d(4 * self.nf, 3, 3, padding=1), + nn.Conv2d(2 * self.nf, 3, 3, padding=1), + nn.Conv2d(1 * self.nf, 3, 3, padding=1) + ]) + + self.up = nn.Upsample(scale_factor=2) + + def encode(self, input_tensor): + """ + Encode input_tensor into feature maps, can be overridden in derived classes + Default: nearest downsampling of 2**5 = 32 times + """ + h, w = input_tensor.size()[-2:] + sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio + x = F.interpolate(input_tensor, size=(sh, sw)) + return self.fc(x) + + def forward(self, x): + # In oroginal SPADE, seg means a segmentation map, but here we use x instead. + seg = x + + x = self.encode(x) + x = self.head_0(x, seg) + + x = self.up(x) + x = self.g_middle_0(x, seg) + x = self.g_middle_1(x, seg) + + if self.is_train: + phase = self.train_phase + 1 + else: + phase = len(self.to_rgbs) + + for i in range(phase): + x = self.up(x) + x = self.ups[i](x, seg) + + x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1)) + x = torch.tanh(x) + + return x + + def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'): + """ + A helper class for subspace visualization. 
Input and seg are different images. + For the first n levels (including encoder) we use input, for the rest we use seg. + + If mode = 'progressive', the output's like: AAABBB + If mode = 'one_plug', the output's like: AAABAA + If mode = 'one_ablate', the output's like: BBBABB + """ + + if seg is None: + return self.forward(input_x) + + if self.is_train: + phase = self.train_phase + 1 + else: + phase = len(self.to_rgbs) + + if mode == 'progressive': + n = max(min(n, 4 + phase), 0) + guide_list = [input_x] * n + [seg] * (4 + phase - n) + elif mode == 'one_plug': + n = max(min(n, 4 + phase - 1), 0) + guide_list = [seg] * (4 + phase) + guide_list[n] = input_x + elif mode == 'one_ablate': + if n > 3 + phase: + return self.forward(input_x) + guide_list = [input_x] * (4 + phase) + guide_list[n] = seg + + x = self.encode(guide_list[0]) + x = self.head_0(x, guide_list[1]) + + x = self.up(x) + x = self.g_middle_0(x, guide_list[2]) + x = self.g_middle_1(x, guide_list[3]) + + for i in range(phase): + x = self.up(x) + x = self.ups[i](x, guide_list[4 + i]) + + x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1)) + x = torch.tanh(x) + + return x + + +@ARCH_REGISTRY.register() +class HiFaceGAN(SPADEGenerator): + """ + HiFaceGAN: SPADEGenerator with a learnable feature encoder + Current encoder design: LIPEncoder + """ + + def __init__(self, + num_in_ch=3, + num_feat=64, + use_vae=False, + z_dim=256, + crop_size=512, + norm_g='spectralspadesyncbatch3x3', + is_train=True, + init_train_phase=3): + super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase) + self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio) + + def encode(self, input_tensor): + return self.lip_encoder(input_tensor) + + +@ARCH_REGISTRY.register() +class HiFaceGANDiscriminator(BaseNetwork): + """ + Inspired by pix2pixHD multiscale discriminator. + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + conditional_d (bool): Whether use conditional discriminator. + Default: True. + num_d (int): Number of Multiscale discriminators. Default: 3. + n_layers_d (int): Number of downsample layers in each D. Default: 4. + num_feat (int): Channel number of base intermediate features. + Default: 64. + norm_d (str): String to determine normalization layers in D. + Choices: [spectral][instance/batch/syncbatch] + Default: 'spectralinstance'. + keep_features (bool): Keep intermediate features for matching loss, etc. + Default: True. + """ + + def __init__(self, + num_in_ch=3, + num_out_ch=3, + conditional_d=True, + num_d=2, + n_layers_d=4, + num_feat=64, + norm_d='spectralinstance', + keep_features=True): + super().__init__() + self.num_d = num_d + + input_nc = num_in_ch + if conditional_d: + input_nc += num_out_ch + + for i in range(num_d): + subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features) + self.add_module(f'discriminator_{i}', subnet_d) + + def downsample(self, x): + return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) + + # Returns list of lists of discriminator outputs. 
+ # The final result is of size opt.num_d x opt.n_layers_D + def forward(self, x): + result = [] + for _, _net_d in self.named_children(): + out = _net_d(x) + result.append(out) + x = self.downsample(x) + + return result + + +class NLayerDiscriminator(BaseNetwork): + """Defines the PatchGAN discriminator with the specified arguments.""" + + def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features): + super().__init__() + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + nf = num_feat + self.keep_features = keep_features + + norm_layer = get_nonspade_norm_layer(norm_d) + sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]] + + for n in range(1, n_layers_d): + nf_prev = nf + nf = min(nf * 2, 512) + stride = 1 if n == n_layers_d - 1 else 2 + sequence += [[ + norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)), + nn.LeakyReLU(0.2, False) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + # We divide the layers into groups to extract intermediate layer outputs + for n in range(len(sequence)): + self.add_module('model' + str(n), nn.Sequential(*sequence[n])) + + def forward(self, x): + results = [x] + for submodel in self.children(): + intermediate_output = submodel(results[-1]) + results.append(intermediate_output) + + if self.keep_features: + return results[1:] + else: + return results[-1] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py new file mode 100644 index 0000000000000000000000000000000000000000..35cbef3f532fcc6aab0fa57ab316a546d3a17bd5 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py @@ -0,0 +1,255 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +# Warning: spectral norm could be buggy +# under eval mode and multi-GPU inference +# A workaround is sticking to single-GPU inference and train mode +from torch.nn.utils import spectral_norm + + +class SPADE(nn.Module): + + def __init__(self, config_text, norm_nc, label_nc): + super().__init__() + + assert config_text.startswith('spade') + parsed = re.search('spade(\\D+)(\\d)x\\d', config_text) + param_free_norm_type = str(parsed.group(1)) + ks = int(parsed.group(2)) + + if param_free_norm_type == 'instance': + self.param_free_norm = nn.InstanceNorm2d(norm_nc) + elif param_free_norm_type == 'syncbatch': + print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead') + self.param_free_norm = nn.InstanceNorm2d(norm_nc) + elif param_free_norm_type == 'batch': + self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) + else: + raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE') + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 if norm_nc > 128 else norm_nc + + pw = ks // 2 + self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False) + + def forward(self, x, segmap): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * gamma + beta + + return out + + +class SPADEResnetBlock(nn.Module): + """ + ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that + it takes in the segmentation map as input, learns the skip connection if necessary, + and applies normalization first and then convolution. + This architecture seemed like a standard architecture for unconditional or + class-conditional GAN architecture using residual block. + The code was inspired from https://github.com/LMescheder/GAN_stability. + """ + + def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # apply spectral norm if specified + if 'spectral' in norm_g: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + + # define normalization layers + spade_config_str = norm_g.replace('spectral', '') + self.norm_0 = SPADE(spade_config_str, fin, semantic_nc) + self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, semantic_nc) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.act(self.norm_0(x, seg))) + dx = self.conv_1(self.act(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + def act(self, x): + return F.leaky_relu(x, 2e-1) + + +class BaseNetwork(nn.Module): + """ A basis for hifacegan archs with custom initialization """ + + def init_weights(self, init_type='normal', gain=0.02): + + def init_func(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError(f'initialization method [{init_type}] is not implemented') + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, 
gain) + + def forward(self, x): + pass + + +def lip2d(x, logit, kernel=3, stride=2, padding=1): + weight = logit.exp() + return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding) + + +class SoftGate(nn.Module): + COEFF = 12.0 + + def forward(self, x): + return torch.sigmoid(x).mul(self.COEFF) + + +class SimplifiedLIP(nn.Module): + + def __init__(self, channels): + super(SimplifiedLIP, self).__init__() + self.logit = nn.Sequential( + nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True), + SoftGate()) + + def init_layer(self): + self.logit[0].weight.data.fill_(0.0) + + def forward(self, x): + frac = lip2d(x, self.logit(x)) + return frac + + +class LIPEncoder(BaseNetwork): + """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)""" + + def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d): + super().__init__() + self.sw = sw + self.sh = sh + self.max_ratio = 16 + # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold + kw = 3 + pw = (kw - 1) // 2 + + model = [ + nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False), + norm_layer(ngf), + nn.ReLU(), + ] + cur_ratio = 1 + for i in range(n_2xdown): + next_ratio = min(cur_ratio * 2, self.max_ratio) + model += [ + SimplifiedLIP(ngf * cur_ratio), + nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw), + norm_layer(ngf * next_ratio), + ] + cur_ratio = next_ratio + if i < n_2xdown - 1: + model += [nn.ReLU(inplace=True)] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +def get_nonspade_norm_layer(norm_type='instance'): + # helper function to get # output channels of the previous layer + def get_out_channel(layer): + if hasattr(layer, 'out_channels'): + return getattr(layer, 'out_channels') + return layer.weight.size(0) + + # this function will be returned + def add_norm_layer(layer): + nonlocal norm_type + if norm_type.startswith('spectral'): + layer = spectral_norm(layer) + subnorm_type = norm_type[len('spectral'):] + + if subnorm_type == 'none' or len(subnorm_type) == 0: + return layer + + # remove bias in the previous layer, which is meaningless + # since it has no effect after normalization + if getattr(layer, 'bias', None) is not None: + delattr(layer, 'bias') + layer.register_parameter('bias', None) + + if subnorm_type == 'batch': + norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == 'sync_batch': + print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead') + # norm_layer = SynchronizedBatchNorm2d( + # get_out_channel(layer), affine=True) + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + elif subnorm_type == 'instance': + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + else: + raise ValueError(f'normalization layer {subnorm_type} is not recognized') + + return nn.Sequential(layer, norm_layer) + + print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.') + return add_norm_layer diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..de1abef67270dc1aba770943b53577029141f527 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py @@ -0,0 +1,307 @@ +# Modified from 
https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501 +# For FID metric + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.model_zoo import load_url +from torchvision import models + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling features + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3. + + Args: + output_blocks (list[int]): Indices of blocks to return features of. + Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input (bool): If true, bilinearly resizes input to width and + height 299 before feeding input to model. As the network + without fully connected layers is fully convolutional, it + should be able to handle inputs of arbitrary size, so resizing + might not be strictly needed. Default: True. + normalize_input (bool): If true, scales the input from range (0, 1) + to the range the pretrained Inception network expects, + namely (-1, 1). Default: True. + requires_grad (bool): If true, parameters of the model require + gradients. Possibly useful for finetuning the network. + Default: False. + use_fid_inception (bool): If true, uses the pretrained Inception + model used in Tensorflow's FID implementation. + If false, uses the pretrained Inception model available in + torchvision. The FID Inception model has different weights + and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get + comparable results. Default: True. 
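# A hedged sketch of how the FID feature extractor above is typically used: pick the block
# index for the 2048-d pooled features, feed images scaled to (0, 1), and flatten the output.
# Assumes `r_basicsr.archs.inception` (the module added by this patch) is importable and that
# the pretrained FID weights can be downloaded; illustration only, not part of the patch.
import torch
from r_basicsr.archs.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pooling features
model = InceptionV3(output_blocks=[block_idx]).eval()
imgs = torch.rand(4, 3, 128, 128)                  # values in (0, 1); normalize_input rescales
with torch.no_grad():
    feats = model(imgs)[0]                         # list with one entry per requested block
feats = feats.squeeze(-1).squeeze(-1)              # (4, 2048) feature matrix for FID stats
print(feats.shape)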
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, ('Last possible output block index is 3') + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + try: + inception = models.inception_v3(pretrained=True, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, x): + """Get Inception feature maps. + + Args: + x (Tensor): Input tensor of shape (b, 3, h, w). + Values are expected to be in range (-1, 1). You can also input + (0, 1) with setting normalize_input = True. + + Returns: + list[Tensor]: Corresponding to the selected output block, sorted + ascending by index. + """ + output = [] + + if self.resize_input: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + output.append(x) + + if idx == self.last_needed_block: + break + + return output + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation. + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + try: + inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False) + + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + if os.path.exists(LOCAL_FID_WEIGHTS): + state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage) + else: + state_dict = load_url(FID_WEIGHTS_URL, progress=True) + + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: 
Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1714361b083777156cefe188a047ef6819a3b940 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py @@ -0,0 +1,135 @@ +import torch +from torch import nn as nn + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import Upsample, make_layer + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + """ + + def __init__(self, num_feat, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class RCAB(nn.Module): + """Residual Channel Attention Block (RCAB) used in RCAN. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + res_scale (float): Scale the residual. Default: 1. + """ + + def __init__(self, num_feat, squeeze_factor=16, res_scale=1): + super(RCAB, self).__init__() + self.res_scale = res_scale + + self.rcab = nn.Sequential( + nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), + ChannelAttention(num_feat, squeeze_factor)) + + def forward(self, x): + res = self.rcab(x) * self.res_scale + return res + x + + +class ResidualGroup(nn.Module): + """Residual Group of RCAB. + + Args: + num_feat (int): Channel number of intermediate features. + num_block (int): Block number in the body network. + squeeze_factor (int): Channel squeeze factor. Default: 16. + res_scale (float): Scale the residual. Default: 1. 
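# A small sketch of the squeeze-and-excitation style channel attention defined above for
# RCAN: global average pooling yields one descriptor per channel, a bottleneck of 1x1 convs
# turns it into per-channel weights in (0, 1), and the input is rescaled channel-wise.
# Plain PyTorch mirror of ChannelAttention; the numbers are illustrative.
import torch
import torch.nn as nn

num_feat, squeeze_factor = 64, 16
attention = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),                              # (b, c, 1, 1) channel descriptor
    nn.Conv2d(num_feat, num_feat // squeeze_factor, 1),   # squeeze
    nn.ReLU(inplace=True),
    nn.Conv2d(num_feat // squeeze_factor, num_feat, 1),   # excite
    nn.Sigmoid(),                                         # per-channel weights in (0, 1)
)
x = torch.randn(2, num_feat, 32, 32)
y = x * attention(x)                                      # broadcast over spatial dims
print(y.shape)  # torch.Size([2, 64, 32, 32])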
+ """ + + def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): + super(ResidualGroup, self).__init__() + + self.residual_group = make_layer( + RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) + self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + def forward(self, x): + res = self.conv(self.residual_group(x)) + return res + x + + +@ARCH_REGISTRY.register() +class RCAN(nn.Module): + """Residual Channel Attention Networks. + + Paper: Image Super-Resolution Using Very Deep Residual Channel Attention + Networks + Ref git repo: https://github.com/yulunzhang/RCAN. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64. + num_group (int): Number of ResidualGroup. Default: 10. + num_block (int): Number of RCAB in ResidualGroup. Default: 16. + squeeze_factor (int): Channel squeeze factor. Default: 16. + upscale (int): Upsampling factor. Support 2^n and 3. + Default: 4. + res_scale (float): Used to scale the residual in residual block. + Default: 1. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. + """ + + def __init__(self, + num_in_ch, + num_out_ch, + num_feat=64, + num_group=10, + num_block=16, + squeeze_factor=16, + upscale=4, + res_scale=1, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040)): + super(RCAN, self).__init__() + + self.img_range = img_range + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer( + ResidualGroup, + num_group, + num_feat=num_feat, + num_block=num_block, + squeeze_factor=squeeze_factor, + res_scale=res_scale) + self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + self.mean = self.mean.type_as(x) + + x = (x - self.mean) * self.img_range + x = self.conv_first(x) + res = self.conv_after_body(self.body(x)) + res += x + + x = self.conv_last(self.upsample(res)) + x = x / self.img_range + self.mean + + return x diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..22a0ae25ed8f9b818f570c802a99b18ca9e96118 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, make_layer + + +class MeanShift(nn.Conv2d): + """ Data normalization with mean and std. + + Args: + rgb_range (int): Maximum value of RGB. + rgb_mean (list[float]): Mean for RGB channels. + rgb_std (list[float]): Std for RGB channels. + sign (int): For subtraction, sign is -1, for addition, sign is 1. + Default: -1. + requires_grad (bool): Whether to update the self.weight and self.bias. + Default: True. 
+ """ + + def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True): + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) + self.weight.data.div_(std.view(3, 1, 1, 1)) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) + self.bias.data.div_(std) + self.requires_grad = requires_grad + + +class EResidualBlockNoBN(nn.Module): + """Enhanced Residual block without BN. + + There are three convolution layers in residual branch. + + It has a style of: + ---Conv-ReLU-Conv-ReLU-Conv-+-ReLU- + |__________________________| + """ + + def __init__(self, in_channels, out_channels): + super(EResidualBlockNoBN, self).__init__() + + self.body = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 1, 1, 0), + ) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.body(x) + out = self.relu(out + x) + return out + + +class MergeRun(nn.Module): + """ Merge-and-run unit. + + This unit contains two branches with different dilated convolutions, + followed by a convolution to process the concatenated features. + + Paper: Real Image Denoising with Feature Attention + Ref git repo: https://github.com/saeed-anwar/RIDNet + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): + super(MergeRun, self).__init__() + + self.dilation1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True)) + self.dilation2 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True)) + + self.aggregation = nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True)) + + def forward(self, x): + dilation1 = self.dilation1(x) + dilation2 = self.dilation2(x) + out = torch.cat([dilation1, dilation2], dim=1) + out = self.aggregation(out) + out = out + x + return out + + +class ChannelAttention(nn.Module): + """Channel attention. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: + """ + + def __init__(self, mid_channels, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class EAM(nn.Module): + """Enhancement attention modules (EAM) in RIDNet. + + This module contains a merge-and-run unit, a residual block, + an enhanced residual block and a feature attention unit. + + Attributes: + merge: The merge-and-run unit. + block1: The residual block. + block2: The enhanced residual block. + ca: The feature/channel attention unit. 
+ """ + + def __init__(self, in_channels, mid_channels, out_channels): + super(EAM, self).__init__() + + self.merge = MergeRun(in_channels, mid_channels) + self.block1 = ResidualBlockNoBN(mid_channels) + self.block2 = EResidualBlockNoBN(mid_channels, out_channels) + self.ca = ChannelAttention(out_channels) + # The residual block in the paper contains a relu after addition. + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.merge(x) + out = self.relu(self.block1(out)) + out = self.block2(out) + out = self.ca(out) + return out + + +@ARCH_REGISTRY.register() +class RIDNet(nn.Module): + """RIDNet: Real Image Denoising with Feature Attention. + + Ref git repo: https://github.com/saeed-anwar/RIDNet + + Args: + in_channels (int): Channel number of inputs. + mid_channels (int): Channel number of EAM modules. + Default: 64. + out_channels (int): Channel number of outputs. + num_block (int): Number of EAM. Default: 4. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. + """ + + def __init__(self, + in_channels, + mid_channels, + out_channels, + num_block=4, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040), + rgb_std=(1.0, 1.0, 1.0)): + super(RIDNet, self).__init__() + + self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std) + self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1) + + self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) + self.body = make_layer( + EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels) + self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + res = self.sub_mean(x) + res = self.tail(self.body(self.relu(self.head(res)))) + res = self.add_mean(res) + + out = x + res + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c4c4de08c25e87f4ab6ba0bf0eac5b6f003d35 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. 
+ """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0756acf982292f077a20cf499ed19936919c1c --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py @@ -0,0 +1,96 @@ +import math +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import flow_warp + + +class BasicModule(nn.Module): + """Basic Module for SpyNet. + """ + + def __init__(self): + super(BasicModule, self).__init__() + + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + return self.basic_module(tensor_input) + + +@ARCH_REGISTRY.register() +class SpyNet(nn.Module): + """SpyNet architecture. + + Args: + load_path (str): path for pretrained SpyNet. Default: None. 
+ """ + + def __init__(self, load_path=None): + super(SpyNet, self).__init__() + self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def preprocess(self, tensor_input): + tensor_output = (tensor_input - self.mean) / self.std + return tensor_output + + def process(self, ref, supp): + flow = [] + + ref = [self.preprocess(ref)] + supp = [self.preprocess(supp)] + + for level in range(5): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + flow = ref[0].new_zeros( + [ref[0].size(0), 2, + int(math.floor(ref[0].size(2) / 2.0)), + int(math.floor(ref[0].size(3) / 2.0))]) + + for level in range(len(ref)): + upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + + if upsampled_flow.size(2) != ref[level].size(2): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') + if upsampled_flow.size(3) != ref[level].size(3): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') + + flow = self.basic_module[level](torch.cat([ + ref[level], + flow_warp( + supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), + upsampled_flow + ], 1)) + upsampled_flow + + return flow + + def forward(self, ref, supp): + assert ref.size() == supp.size() + + h, w = ref.size(2), ref.size(3) + w_floor = math.floor(math.ceil(w / 32.0) * 32.0) + h_floor = math.floor(math.ceil(h / 32.0) * 32.0) + + ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + + flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) + + flow[:, 0, :, :] *= float(w) / float(w_floor) + flow[:, 1, :, :] *= float(h) / float(h_floor) + + return flow diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..f922ecc1a82428d0770d8f566e6c501b332be3e7 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py @@ -0,0 +1,65 @@ +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer + + +@ARCH_REGISTRY.register() +class MSRResNet(nn.Module): + """Modified SRResNet. + + A compacted version modified from SRResNet in + "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" + It uses residual blocks without BN, similar to EDSR. + Currently, it supports x2, x3 and x4 upsampling scale factor. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_block (int): Block number in the body network. Default: 16. + upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 
+ """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): + super(MSRResNet, self).__init__() + self.upscale = upscale + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) + + # upsampling + if self.upscale in [2, 3]: + self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(self.upscale) + elif self.upscale == 4: + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(2) + + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # initialization + default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) + if self.upscale == 4: + default_init_weights(self.upconv2, 0.1) + + def forward(self, x): + feat = self.lrelu(self.conv_first(x)) + out = self.body(feat) + + if self.upscale == 4: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + elif self.upscale in [2, 3]: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + + out = self.conv_last(self.lrelu(self.conv_hr(out))) + base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) + out += base + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..02267d80c6d4fbdcc2ed8ac96ef0e51ab21ac727 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py @@ -0,0 +1,70 @@ +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register(suffix='basicsr') +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
+ """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..b9bbf4de7a74a1631bcefe6c721d14e37d15de11 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py @@ -0,0 +1,799 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu +from r_basicsr.ops.upfirdn2d import upfirdn2d +from r_basicsr.utils.registry import ARCH_REGISTRY + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +def make_resample_kernel(k): + """Make resampling kernel for UpFirDn. + + Args: + k (list[int]): A list indicating the 1D resample kernel magnitude. + + Returns: + Tensor: 2D resampled kernel. + """ + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] # to 2D kernel, outer product + # normalize + k /= k.sum() + return k + + +class UpFirDnUpsample(nn.Module): + """Upsample, FIR filter, and downsample (upsampole version). + + References: + 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501 + 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501 + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Upsampling scale factor. Default: 2. 
+ """ + + def __init__(self, resample_kernel, factor=2): + super(UpFirDnUpsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) * (factor**2) + self.factor = factor + + pad = self.kernel.shape[0] - factor + self.pad = ((pad + 1) // 2 + factor - 1, pad // 2) + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(factor={self.factor})') + + +class UpFirDnDownsample(nn.Module): + """Upsample, FIR filter, and downsample (downsampole version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Downsampling scale factor. Default: 2. + """ + + def __init__(self, resample_kernel, factor=2): + super(UpFirDnDownsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) + self.factor = factor + + pad = self.kernel.shape[0] - factor + self.pad = ((pad + 1) // 2, pad // 2) + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(factor={self.factor})') + + +class UpFirDnSmooth(nn.Module): + """Upsample, FIR filter, and downsample (smooth version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + upsample_factor (int): Upsampling scale factor. Default: 1. + downsample_factor (int): Downsampling scale factor. Default: 1. + kernel_size (int): Kernel size: Default: 1. + """ + + def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1): + super(UpFirDnSmooth, self).__init__() + self.upsample_factor = upsample_factor + self.downsample_factor = downsample_factor + self.kernel = make_resample_kernel(resample_kernel) + if upsample_factor > 1: + self.kernel = self.kernel * (upsample_factor**2) + + if upsample_factor > 1: + pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1) + self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1) + elif downsample_factor > 1: + pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1) + self.pad = ((pad + 1) // 2, pad // 2) + else: + raise NotImplementedError + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}' + f', downsample_factor={self.downsample_factor})') + + +class EqualLinear(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Size of each sample. + out_channels (int): Size of each output sample. + bias (bool): If set to ``False``, the layer will not learn an additive + bias. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + lr_mul (float): Learning rate multiplier. Default: 1. + activation (None | str): The activation after ``linear`` operation. + Supported: 'fused_lrelu', None. Default: None. 
+ """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1), + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + if self.sample_mode == 'upsample': + self.smooth = UpFirDnSmooth( + resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size) + elif self.sample_mode == 'downsample': + self.smooth = UpFirDnSmooth( + resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size) + elif self.sample_mode is None: + pass + else: + raise ValueError(f'Wrong sample mode {self.sample_mode}, ' + "supported ones are ['upsample', 'downsample', None].") + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = x.view(1, b * c, h, w) + weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size) + out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + out = self.smooth(out) + elif self.sample_mode == 'downsample': + x = self.smooth(x) + x = x.view(1, b * c, *x.shape[2:4]) + out = F.conv2d(x, weight, padding=0, stride=2, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + else: + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1)): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + resample_kernel=resample_kernel) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). 
+ """ + + def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)): + super(ToRGB, self).__init__() + if upsample: + self.upsample = UpFirDnUpsample(resample_kernel, factor=2) + else: + self.upsample = None + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = self.upsample(skip) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2Generator(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. 
+ """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1): + super(StyleGAN2Generator, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + resample_kernel=resample_kernel, + )) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. 
+ truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ScaledLeakyReLU(nn.Module): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. + """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. 
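+
+    Example:
+        A minimal sketch (random weights, shape check only); apart from the
+        runtime weight scaling the layer behaves like a standard ``nn.Conv2d``:
+
+        >>> import torch
+        >>> from r_basicsr.archs.stylegan2_arch import EqualConv2d
+        >>> conv = EqualConv2d(3, 64, kernel_size=3, stride=1, padding=1)
+        >>> tuple(conv(torch.randn(1, 3, 32, 32)).shape)
+        (1, 64, 32, 32)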
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + resample_kernel (list[int]): A list indicating the 1D resample + kernel magnitude. A cross production will be applied to + extent 1D resample kernel to 2D resample kernel. + Default: (1, 3, 3, 1). + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True): + layers = [] + # downsample + if downsample: + layers.append( + UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)) + stride = 2 + self.padding = 0 + else: + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + resample_kernel (list[int]): A list indicating the 1D resample + kernel magnitude. A cross production will be applied to + extent 1D resample kernel to 2D resample kernel. + Default: (1, 3, 3, 1). + """ + + def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True) + self.skip = ConvLayer( + in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2Discriminator(nn.Module): + """StyleGAN2 Discriminator. + + Args: + out_size (int): The spatial size of outputs. 
+ channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kernel to 2D resample kernel. Default: (1, 3, 3, 1). + stddev_group (int): For group stddev statistics. Default: 4. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1): + super(StyleGAN2Discriminator, self).__init__() + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + + log_size = int(math.log(out_size, 2)) + + conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)] + + in_channels = channels[f'{out_size}'] + for i in range(log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + conv_body.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + self.conv_body = nn.Sequential(*conv_body) + + self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True) + self.final_linear = nn.Sequential( + EqualLinear( + channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'), + EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None), + ) + self.stddev_group = stddev_group + self.stddev_feat = 1 + + def forward(self, x): + out = self.conv_body(x) + + b, c, h, w = out.shape + # concatenate a group stddev statistics to out + group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size + stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, h, w) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + out = out.view(b, -1) + out = self.final_linear(out) + + return out diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a8468cf6252138fefdfcc3089c2a277f9162ab45 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py @@ -0,0 +1,956 @@ +# Modified from https://github.com/JingyunLiang/SwinIR +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. + +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import to_2tuple, trunc_normal_ + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, n): + # calculate flops for 1 window with token length of n + flops = 0 + # qkv = self.qkv(x) + flops += n * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * n * (self.dim // self.num_heads) * n + # x = (attn @ v) + flops += self.num_heads * n * n * (self.dim // self.num_heads) + # x = self.proj(x) + flops += n * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. 
+ window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if 
self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}') + + def flops(self): + flops = 0 + h, w = self.input_resolution + # norm1 + flops += self.dim * h * w + # W-MSA/SW-MSA + nw = h * w / self.window_size / self.window_size + flops += nw * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * h * w + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, 'input feature has wrong size' + assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.' + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f'input_resolution={self.input_resolution}, dim={self.dim}' + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.dim + flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
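+
+    Example:
+        A minimal sketch (random weights, shape check only); the block takes
+        token embeddings of shape (b, h*w, dim) together with the spatial size:
+
+        >>> import torch
+        >>> from r_basicsr.archs.swinir_arch import RSTB
+        >>> rstb = RSTB(dim=60, input_resolution=(64, 64), depth=2, num_heads=6,
+        ...             window_size=8, img_size=64, patch_size=1)
+        >>> x = torch.randn(1, 64 * 64, 60)
+        >>> tuple(rstb(x, (64, 64)).shape)
+        (1, 4096, 60)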
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + h, w = self.input_resolution + flops += h * w * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + h, w = self.img_size + if self.norm is not None: + flops += h * w * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.num_feat * 3 * 9 + return flops + + +@ARCH_REGISTRY.register() +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 
2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1., + upsampler='', + resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + 
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # b seq_len c + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x + + def flops(self): + flops = 0 + h, w = self.patches_resolution + flops += h * w * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for layer 
in self.layers: + flops += layer.flops() + flops += h * w * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR( + upscale=2, + img_size=(height, width), + window_size=window_size, + img_range=1., + depths=[6, 6, 6, 6], + embed_dim=60, + num_heads=[6, 6, 6, 6], + mlp_ratio=2, + upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..81ce55a41a3069db234177f48711fead76b9fe09 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py @@ -0,0 +1,172 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import flow_warp + + +class BasicModule(nn.Module): + """Basic module of SPyNet. + + Note that unlike the architecture in spynet_arch.py, the basic module + here contains batch normalization. + """ + + def __init__(self): + super(BasicModule, self).__init__() + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(64), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(16), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + + Returns: + Tensor: Estimated flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) + + +class SPyNetTOF(nn.Module): + """SPyNet architecture for TOF. + + Note that this implementation is specifically for TOFlow. Please use + spynet_arch.py for general use. They differ in the following aspects: + 1. The basic modules here contain BatchNorm. + 2. Normalization and denormalization are not done here, as + they are done in TOFlow. + Paper: + Optical Flow Estimation using a Spatial Pyramid Network + Code reference: + https://github.com/Coldog2333/pytoflow + + Args: + load_path (str): Path for pretrained SPyNet. Default: None. + """ + + def __init__(self, load_path=None): + super(SPyNetTOF, self).__init__() + + self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + def forward(self, ref, supp): + """ + Args: + ref (Tensor): Reference image with shape of (b, 3, h, w). + supp: The supporting image to be warped: (b, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (b, 2, h, w). 
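+
+        Example:
+            A minimal sketch (random weights, shape check only); the spatial
+            size should be divisible by 16 for the 4-level pyramid:
+
+            >>> import torch
+            >>> from r_basicsr.archs.tof_arch import SPyNetTOF
+            >>> spynet = SPyNetTOF()
+            >>> ref, supp = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
+            >>> tuple(spynet(ref, supp).shape)
+            (1, 2, 64, 64)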
+ """ + num_batches, _, h, w = ref.size() + ref = [ref] + supp = [supp] + + # generate downsampled frames + for _ in range(3): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + # flow computation + flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16) + for i in range(4): + flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module[i]( + torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) + return flow + + +@ARCH_REGISTRY.register() +class TOFlow(nn.Module): + """PyTorch implementation of TOFlow. + + In TOFlow, the LR frames are pre-upsampled and have the same size with + the GT frames. + Paper: + Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 + Code reference: + 1. https://github.com/anchen1011/toflow + 2. https://github.com/Coldog2333/pytoflow + + Args: + adapt_official_weights (bool): Whether to adapt the weights translated + from the official implementation. Set to false if you want to + train from scratch. Default: False + """ + + def __init__(self, adapt_official_weights=False): + super(TOFlow, self).__init__() + self.adapt_official_weights = adapt_official_weights + self.ref_idx = 0 if adapt_official_weights else 3 + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + # flow estimation module + self.spynet = SPyNetTOF() + + # reconstruction module + self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4) + self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4) + self.conv_3 = nn.Conv2d(64, 64, 1) + self.conv_4 = nn.Conv2d(64, 3, 1) + + # activation function + self.relu = nn.ReLU(inplace=True) + + def normalize(self, img): + return (img - self.mean) / self.std + + def denormalize(self, img): + return img * self.std + self.mean + + def forward(self, lrs): + """ + Args: + lrs: Input lr frames: (b, 7, 3, h, w). + + Returns: + Tensor: SR frame: (b, 3, h, w). 
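+
+        Example (illustrative addition; the 7 LR frames are assumed to be
+        pre-upsampled to the GT size, as noted in the class docstring):
+            >>> lrs = torch.rand(1, 7, 3, 64, 64)
+            >>> sr = TOFlow()(lrs)  # torch.Size([1, 3, 64, 64])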
+ """ + # In the official implementation, the 0-th frame is the reference frame + if self.adapt_official_weights: + lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] + + num_batches, num_lrs, _, h, w = lrs.size() + + lrs = self.normalize(lrs.view(-1, 3, h, w)) + lrs = lrs.view(num_batches, num_lrs, 3, h, w) + + lr_ref = lrs[:, self.ref_idx, :, :, :] + lr_aligned = [] + for i in range(7): # 7 frames + if i == self.ref_idx: + lr_aligned.append(lr_ref) + else: + lr_supp = lrs[:, i, :, :, :] + flow = self.spynet(lr_ref, lr_supp) + lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1))) + + # reconstruction + hr = torch.stack(lr_aligned, dim=1) + hr = hr.view(num_batches, -1, h, w) + hr = self.relu(self.conv_1(hr)) + hr = self.relu(self.conv_2(hr)) + hr = self.relu(self.conv_3(hr)) + hr = self.conv_4(hr) + lr_ref + + return self.denormalize(hr) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..517b79ae094d63648707c9700d5cbfb6c520b6fb --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from r_basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. 
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
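+
+        Note (illustrative addition): the result is a dict keyed by the
+        requested layer names, e.g. with layer_name_list=['relu1_1'] the
+        output is {'relu1_1': features of shape (n, 64, h, w)} for vgg19.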
+ """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b53e5eedf2b839738944166d88f9f8a95177bf2e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py @@ -0,0 +1,101 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from r_basicsr.data.prefetch_dataloader import PrefetchDataLoader +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.dist_util import get_dist_info +from r_basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'r_basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. 
Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. 
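+
+    Example (illustrative addition): with len(dataset) = 100, num_replicas = 4
+    and ratio = 2, each replica yields num_samples = ceil(200 / 4) = 50 indices
+    per epoch; indices are taken modulo the dataset size, so samples may repeat.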
+ """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..501f14c8cc8a47e768b4531e44e5d3e39d271e0d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py @@ -0,0 +1,313 @@ +import cv2 +import numpy as np +import torch +from os import path as osp +from torch.nn import functional as F + +from r_basicsr.data.transforms import mod_crop +from r_basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + return_imgname(bool): Whether return image names. Default False. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + list[str]: Returned image name list. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + + if return_imgname: + imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths] + return imgs, imgnames + else: + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' 
+ + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. 
+ keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.strip().split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.' + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. 
+ + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py new file mode 100644 index 0000000000000000000000000000000000000000..697d35ccb560e902eee975c12ec728bea3b8b3c6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py @@ -0,0 +1,768 @@ +import cv2 +import math +import numpy as np +import random +import torch +from scipy import special +from scipy.stats import multivariate_normal +try: + from torchvision.transforms.functional_tensor import rgb_to_grayscale +except: + from torchvision.transforms.functional import rgb_to_grayscale + +# -------------------------------------------------------------------- # +# --------------------------- blur kernels --------------------------- # +# -------------------------------------------------------------------- # + + +# --------------------------- util functions --------------------------- # +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. 
+ + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(d_matrix, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + + Args: + d_matrix (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, d_matrix) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a bivariate generalized Gaussian kernel. + Described in `Parameter Estimation For Multivariate Generalized + Gaussian Distributions`_ + by Pascal et. al (2013). + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + + .. _Parameter Estimation For Multivariate Generalized Gaussian + Distributions: https://arxiv.org/abs/1302.6498 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a plateau-like anisotropic kernel. + 1 / (1+x^(beta)) + + Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. 
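+
+    Example (illustrative):
+        >>> k = bivariate_plateau(21, 2.0, 2.0, 0.0, beta=2, isotropic=True)
+        >>> k.shape, round(float(k.sum()), 6)  # ((21, 21), 1.0)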
+ """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate generalized Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # assume beta_range[0] < 1 < beta_range[1] + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' 
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_plateau(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate plateau kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # TODO: this may be not proper + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. 
Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) + elif kernel_type == 'generalized_iso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'generalized_aniso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) + return kernel + + +np.seterr(divide='ignore', invalid='ignore') + + +def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): + """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + + Args: + cutoff (float): cutoff frequency in radians (pi is max) + kernel_size (int): horizontal and vertical size, must be odd. + pad_to (int): pad kernel size to desired size, must be odd or zero. + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + kernel = np.fromfunction( + lambda x, y: cutoff * special.j1(cutoff * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) + kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) + kernel = kernel / np.sum(kernel) + if pad_to > kernel_size: + pad_size = (pad_to - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + return kernel + + +# ------------------------------------------------------------- # +# --------------------------- noise --------------------------- # +# ------------------------------------------------------------- # + +# ----------------------- Gaussian Noise ----------------------- # + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): + """Add Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. 
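+
+    Example (illustrative; `img` is a float32 HWC array in [0, 1]):
+        >>> noisy = add_gaussian_noise(img, sigma=25, clip=True)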
+ """ + noise = generate_gaussian_noise(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if not isinstance(sigma, (float, int)): + sigma = sigma.view(img.size(0), 1, 1, 1) + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + + if cal_gray_noise: + noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. + noise_gray = noise_gray.view(b, 1, h, w) + + # always calculate color noise + noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. + + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + return noise + + +def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_gaussian_noise_pt(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Gaussian Noise ----------------------- # +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): + sigma = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_gaussian_noise_pt(img, sigma, gray_noise) + + +def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. 
+ return out + + +# ----------------------- Poisson (Shot) Noise ----------------------- # + + +def generate_poisson_noise(img, scale=1.0, gray_noise=False): + """Generate poisson noise. + + Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # round and clip image for counting vals correctly + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = len(np.unique(img)) + vals = 2**np.ceil(np.log2(vals)) + out = np.float32(np.random.poisson(img * vals) / float(vals)) + noise = out - img + if gray_noise: + noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) + return noise * scale + + +def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): + """Add poisson noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_poisson_noise(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): + """Generate a batch of poisson noise (PyTorch version) + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + if cal_gray_noise: + img_gray = rgb_to_grayscale(img, num_output_channels=1) + # round and clip image for counting vals correctly + img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img_gray * vals) / vals + noise_gray = out - img_gray + noise_gray = noise_gray.expand(b, 3, h, w) + + # always calculate color noise + # round and clip image for counting vals correctly + img = torch.clamp((img * 255.0).round(), 0, 255) / 255. 
+ # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img * vals) / vals + noise = out - img + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + if not isinstance(scale, (float, int)): + scale = scale.view(b, 1, 1, 1) + return noise * scale + + +def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): + """Add poisson noise to a batch of images (PyTorch version). + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_poisson_noise_pt(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Poisson (Shot) Noise ----------------------- # + + +def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): + scale = np.random.uniform(scale_range[0], scale_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_poisson_noise(img, scale, gray_noise) + + +def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): + scale = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_poisson_noise_pt(img, scale, gray_noise) + + +def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ------------------------------------------------------------------------ # +# --------------------------- JPEG compression --------------------------- # +# ------------------------------------------------------------------------ # + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. 
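+
+    Example (illustrative; `img` is a float32 HWC array in [0, 1]):
+        >>> degraded = add_jpg_compression(img, quality=50)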
+ """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = np.random.uniform(quality_range[0], quality_range[1]) + return add_jpg_compression(img, quality) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..48e2c3a0ce93dd8f8e2660422462b17b2fbd340c --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py @@ -0,0 +1,80 @@ +import random +import time +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from r_basicsr.data.transforms import augment +from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class FFHQDataset(data.Dataset): + """FFHQ dataset for StyleGAN. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. 
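+
+        Example opt (illustrative addition; the path and normalization values
+        below are hypothetical):
+            {'dataroot_gt': 'datasets/ffhq', 'io_backend': {'type': 'disk'},
+             'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'use_hflip': True}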
+
+    """
+
+    def __init__(self, opt):
+        super(FFHQDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            # FFHQ has 70000 images in total
+            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load gt image
+        gt_path = self.paths[index]
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                img_bytes = self.file_client.get(gt_path)
+            except Exception as e:
+                logger = get_root_logger()
+                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # retry with another randomly chosen file
+                index = random.randint(0, self.__len__() - 1)  # randint is inclusive at both ends
+                gt_path = self.paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # random horizontal flip
+        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+        # normalize
+        normalize(img_gt, self.mean, self.std, inplace=True)
+        return {'gt': img_gt, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..480c7d2b2a41c728a40507ea88b14cf83dac6e75
--- /dev/null
+++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py
@@ -0,0 +1,108 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
+from r_basicsr.data.transforms import augment, paired_random_crop
+from r_basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+    """Paired image dataset for image restoration.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+    There are three modes:
+    1. 'lmdb': Use lmdb files.
+        If opt['io_backend'] == lmdb.
+    2. 'meta_info_file': Use meta information file to generate paths.
+        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+    3. 'folder': Scan folders to generate paths.
+        The rest.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            meta_info_file (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+                Default: '{}'.
+ gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] + img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] + + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from 
torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. 
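+
+    Example (editor's illustrative sketch, not original BasicSR code;
+    ``train_loader`` and ``opt`` are placeholder names for a ready
+    DataLoader and a config dict containing ``num_gpu``)::
+
+        prefetcher = CUDAPrefetcher(train_loader, opt)
+        batch = prefetcher.next()
+        while batch is not None:
+            # run the training step on the current batch here
+            batch = prefetcher.next()
+        prefetcher.reset()  # rewind before starting the next epoch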
+ """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ab7ba84b08f624cfc0e294c55066d70283f169 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py @@ -0,0 +1,193 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from torch.utils import data as data + +from r_basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from r_basicsr.data.transforms import augment +from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANDataset(data.Dataset): + """Dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. 
+ """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.gt_folder] + self.io_backend_opt['client_keys'] = ['gt'] + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [os.path.join(self.gt_folder, v) for v in paths] + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except (IOError, OSError) as e: + logger = get_root_logger() + logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + + # -------------------- Do augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400 + # TODO: 400 is hard-coded. 
You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = 400 + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_paired_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..29bc126fc892310d8c3c13d38a077885189a54f9 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_paired_dataset.py @@ -0,0 +1,109 @@ +import os +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from 
r_basicsr.data.transforms import augment, paired_random_crop +from r_basicsr.utils import FileClient, imfrombytes, img2tensor +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
+ gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/reds_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/reds_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7be362e00b83d60350a4c98a2d56a14aedf24ac5 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/reds_dataset.py @@ -0,0 +1,360 @@ +import numpy as np +import random +import torch +from pathlib import Path +from torch.utils import data as data + +from r_basicsr.data.transforms import augment, paired_random_crop +from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from r_basicsr.utils.flow_util import dequantize_flow +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class REDSDataset(data.Dataset): + """REDS dataset for training. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_REDS_GT.txt + + Each line contains: + 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by + a white space. + Examples: + 000 100 (720,1280,3) + 001 100 (720,1280,3) + ... + + Key examples: "000/00000000" + GT (gt): Ground-Truth; + LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + dataroot_flow (str, optional): Data root path for flow. + meta_info_file (str): Path for meta information file. + val_partition (str): Validation partition types. 'REDS4' or + 'official'. + io_backend (dict): IO backend type and other kwarg. + + num_frame (int): Window size for input frames. + gt_size (int): Cropped patched size for gt patches. + interval_list (list): Interval list for temporal augmentation. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. 
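+
+        Example (editor's note, not original BasicSR code): with
+        ``num_frame=5`` and a sampled interval of 2, a center frame
+        ``00000010`` gives the neighboring LQ frame indices
+        6, 8, 10, 12, 14, while the GT is the center frame 00000010 itself.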
+ """ + + def __init__(self, opt): + super(REDSDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None + assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}') + self.num_frame = opt['num_frame'] + self.num_half_frames = opt['num_frame'] // 2 + + self.keys = [] + with open(opt['meta_info_file'], 'r') as fin: + for line in fin: + folder, frame_num, _ = line.split(' ') + self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))]) + + # remove the video clips used in validation + if opt['val_partition'] == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif opt['val_partition'] == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {opt["val_partition"]}.' + f"Supported ones are ['official', 'REDS4'].") + self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + if self.flow_root is not None: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow'] + else: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # temporal augmentation configs + self.interval_list = opt['interval_list'] + self.random_reverse = opt['random_reverse'] + interval_str = ','.join(str(x) for x in opt['interval_list']) + logger = get_root_logger() + logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' + f'random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip_name, frame_name = key.split('/') # key example: 000/00000000 + center_frame_idx = int(frame_name) + + # determine the neighboring frames + interval = random.choice(self.interval_list) + + # ensure not exceeding the borders + start_frame_idx = center_frame_idx - self.num_half_frames * interval + end_frame_idx = center_frame_idx + self.num_half_frames * interval + # each clip has 100 frames starting from 0 to 99 + while (start_frame_idx < 0) or (end_frame_idx > 99): + center_frame_idx = random.randint(0, 99) + start_frame_idx = (center_frame_idx - self.num_half_frames * interval) + end_frame_idx = center_frame_idx + self.num_half_frames * interval + frame_name = f'{center_frame_idx:08d}' + neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval)) + # random reverse + if self.random_reverse and random.random() < 0.5: + neighbor_list.reverse() + + assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}') + + # get the GT frame (as the center frame) + if self.is_lmdb: + img_gt_path = f'{clip_name}/{frame_name}' + else: + img_gt_path = self.gt_root / clip_name / f'{frame_name}.png' + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + # get the neighboring LQ frames + img_lqs = [] + for neighbor in neighbor_list: + if self.is_lmdb: + img_lq_path = 
f'{clip_name}/{neighbor:08d}' + else: + img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # get flows + if self.flow_root is not None: + img_flows = [] + # read previous flows + for i in range(self.num_half_frames, 0, -1): + if self.is_lmdb: + flow_path = f'{clip_name}/{frame_name}_p{i}' + else: + flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png') + img_bytes = self.file_client.get(flow_path, 'flow') + cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255] + dx, dy = np.split(cat_flow, 2, axis=0) + flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here. + img_flows.append(flow) + # read next flows + for i in range(1, self.num_half_frames + 1): + if self.is_lmdb: + flow_path = f'{clip_name}/{frame_name}_n{i}' + else: + flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png') + img_bytes = self.file_client.get(flow_path, 'flow') + cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255] + dx, dy = np.split(cat_flow, 2, axis=0) + flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here. + img_flows.append(flow) + + # for random crop, here, img_flows and img_lqs have the same + # spatial size + img_lqs.extend(img_flows) + + # randomly crop + img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path) + if self.flow_root is not None: + img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:] + + # augmentation - flip, rotate + img_lqs.append(img_gt) + if self.flow_root is not None: + img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows) + else: + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[0:-1], dim=0) + img_gt = img_results[-1] + + if self.flow_root is not None: + img_flows = img2tensor(img_flows) + # add the zero center flow + img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) + img_flows = torch.stack(img_flows, dim=0) + + # img_lqs: (t, c, h, w) + # img_flows: (t, 2, h, w) + # img_gt: (c, h, w) + # key: str + if self.flow_root is not None: + return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key} + else: + return {'lq': img_lqs, 'gt': img_gt, 'key': key} + + def __len__(self): + return len(self.keys) + + +@DATASET_REGISTRY.register() +class REDSRecurrentDataset(data.Dataset): + """REDS dataset for training recurrent networks. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_REDS_GT.txt + + Each line contains: + 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by + a white space. + Examples: + 000 100 (720,1280,3) + 001 100 (720,1280,3) + ... + + Key examples: "000/00000000" + GT (gt): Ground-Truth; + LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + dataroot_flow (str, optional): Data root path for flow. + meta_info_file (str): Path for meta information file. + val_partition (str): Validation partition types. 'REDS4' or + 'official'. + io_backend (dict): IO backend type and other kwarg. + + num_frame (int): Window size for input frames. 
+ gt_size (int): Cropped patched size for gt patches. + interval_list (list): Interval list for temporal augmentation. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + """ + + def __init__(self, opt): + super(REDSRecurrentDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + self.num_frame = opt['num_frame'] + + self.keys = [] + with open(opt['meta_info_file'], 'r') as fin: + for line in fin: + folder, frame_num, _ = line.split(' ') + self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))]) + + # remove the video clips used in validation + if opt['val_partition'] == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif opt['val_partition'] == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {opt["val_partition"]}.' + f"Supported ones are ['official', 'REDS4'].") + if opt['test_mode']: + self.keys = [v for v in self.keys if v.split('/')[0] in val_partition] + else: + self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + if hasattr(self, 'flow_root') and self.flow_root is not None: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow'] + else: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # temporal augmentation configs + self.interval_list = opt.get('interval_list', [1]) + self.random_reverse = opt.get('random_reverse', False) + interval_str = ','.join(str(x) for x in self.interval_list) + logger = get_root_logger() + logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' + f'random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip_name, frame_name = key.split('/') # key example: 000/00000000 + + # determine the neighboring frames + interval = random.choice(self.interval_list) + + # ensure not exceeding the borders + start_frame_idx = int(frame_name) + if start_frame_idx > 100 - self.num_frame * interval: + start_frame_idx = random.randint(0, 100 - self.num_frame * interval) + end_frame_idx = start_frame_idx + self.num_frame * interval + + neighbor_list = list(range(start_frame_idx, end_frame_idx, interval)) + + # random reverse + if self.random_reverse and random.random() < 0.5: + neighbor_list.reverse() + + # get the neighboring LQ and GT frames + img_lqs = [] + img_gts = [] + for neighbor in neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip_name}/{neighbor:08d}' + img_gt_path = f'{clip_name}/{neighbor:08d}' + else: + img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' + img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png' + + # get LQ + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # get GT + 
img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + img_gts.append(img_gt) + + # randomly crop + img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.extend(img_gts) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0) + img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0) + + # img_lqs: (t, c, h, w) + # img_gts: (t, c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gts, 'key': key} + + def __len__(self): + return len(self.keys) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/single_image_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..02c501a85f04a43372b0e16f5d7fef52fd63427d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/single_image_dataset.py @@ -0,0 +1,68 @@ +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from r_basicsr.data.data_util import paths_from_lmdb +from r_basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class SingleImageDataset(data.Dataset): + """Read only lq images in the test phase. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). + + There are two modes: + 1. 'meta_info_file': Use meta information file to generate paths. + 2. 'folder': Scan folders to generate paths. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. 
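+
+    Example (editor's illustrative sketch, not original BasicSR code; the
+    folder path below is a placeholder)::
+
+        opt = dict(
+            dataroot_lq='datasets/my_lq_images',   # hypothetical folder
+            io_backend=dict(type='disk'),
+        )
+        dataset = SingleImageDataset(opt)
+        item = dataset[0]   # {'lq': CHW float tensor, 'lq_path': str}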
+ """ + + def __init__(self, opt): + super(SingleImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + self.lq_folder = opt['dataroot_lq'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: + with open(self.opt['meta_info_file'], 'r') as fin: + self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] + else: + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load lq image + lq_path = self.paths[index] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + return {'lq': img_lq, 'lq_path': lq_path} + + def __len__(self): + return len(self.paths) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/transforms.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbb5fb7daef5edfb425fafb4d67d471b3001e6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/transforms.py @@ -0,0 +1,179 @@ +import cv2 +import random +import torch + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. 
+ """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. 
+ + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/video_test_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/video_test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b730a677ca540fa429f7ea9c99c6113a384655d2 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/video_test_dataset.py @@ -0,0 +1,287 @@ +import glob +import torch +from os import path as osp +from torch.utils import data as data + +from r_basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class VideoTestDataset(data.Dataset): + """Video test dataset. + + Supported datasets: Vid4, REDS4, REDSofficial. + More generally, it supports testing dataset with following structures: + + dataroot + ├── subfolder1 + ├── frame000 + ├── frame001 + ├── ... + ├── subfolder1 + ├── frame000 + ├── frame001 + ├── ... + ├── ... + + For testing datasets, there is no need to prepare LMDB files. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + io_backend (dict): IO backend type and other kwarg. + cache_data (bool): Whether to cache testing datasets. + name (str): Dataset name. + meta_info_file (str): The path to the file storing the list of test + folders. If not provided, all the folders in the dataroot will + be used. + num_frame (int): Window size for input frames. + padding (str): Padding mode. 
+ """ + + def __init__(self, opt): + super(VideoTestDataset, self).__init__() + self.opt = opt + self.cache_data = opt['cache_data'] + self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq'] + self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []} + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.' + + logger = get_root_logger() + logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}') + self.imgs_lq, self.imgs_gt = {}, {} + if 'meta_info_file' in opt: + with open(opt['meta_info_file'], 'r') as fin: + subfolders = [line.split(' ')[0] for line in fin] + subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders] + subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders] + else: + subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*'))) + subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*'))) + + if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']: + for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt): + # get frame list for lq and gt + subfolder_name = osp.basename(subfolder_lq) + img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True))) + img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True))) + + max_idx = len(img_paths_lq) + assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})' + f' and gt folders ({len(img_paths_gt)})') + + self.data_info['lq_path'].extend(img_paths_lq) + self.data_info['gt_path'].extend(img_paths_gt) + self.data_info['folder'].extend([subfolder_name] * max_idx) + for i in range(max_idx): + self.data_info['idx'].append(f'{i}/{max_idx}') + border_l = [0] * max_idx + for i in range(self.opt['num_frame'] // 2): + border_l[i] = 1 + border_l[max_idx - i - 1] = 1 + self.data_info['border'].extend(border_l) + + # cache data or save the frame list + if self.cache_data: + logger.info(f'Cache {subfolder_name} for VideoTestDataset...') + self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq) + self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt) + else: + self.imgs_lq[subfolder_name] = img_paths_lq + self.imgs_gt[subfolder_name] = img_paths_gt + else: + raise ValueError(f'Non-supported video test dataset: {type(opt["name"])}') + + def __getitem__(self, index): + folder = self.data_info['folder'][index] + idx, max_idx = self.data_info['idx'][index].split('/') + idx, max_idx = int(idx), int(max_idx) + border = self.data_info['border'][index] + lq_path = self.data_info['lq_path'][index] + + select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) + + if self.cache_data: + imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx)) + img_gt = self.imgs_gt[folder][idx] + else: + img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]]) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': folder, # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/99 + 'border': border, # 1 for border, 0 for non-border + 'lq_path': lq_path # center frame + } + + def __len__(self): + return len(self.data_info['gt_path']) + + +@DATASET_REGISTRY.register() +class VideoTestVimeo90KDataset(data.Dataset): + """Video test dataset for Vimeo90k-Test dataset. 
+ + It only keeps the center frame for testing. + For testing datasets, there is no need to prepare LMDB files. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + io_backend (dict): IO backend type and other kwarg. + cache_data (bool): Whether to cache testing datasets. + name (str): Dataset name. + meta_info_file (str): The path to the file storing the list of test + folders. If not provided, all the folders in the dataroot will + be used. + num_frame (int): Window size for input frames. + padding (str): Padding mode. + """ + + def __init__(self, opt): + super(VideoTestVimeo90KDataset, self).__init__() + self.opt = opt + self.cache_data = opt['cache_data'] + if self.cache_data: + raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.') + self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq'] + self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []} + neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.' + + logger = get_root_logger() + logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}') + with open(opt['meta_info_file'], 'r') as fin: + subfolders = [line.split(' ')[0] for line in fin] + for idx, subfolder in enumerate(subfolders): + gt_path = osp.join(self.gt_root, subfolder, 'im4.png') + self.data_info['gt_path'].append(gt_path) + lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list] + self.data_info['lq_path'].append(lq_paths) + self.data_info['folder'].append('vimeo90k') + self.data_info['idx'].append(f'{idx}/{len(subfolders)}') + self.data_info['border'].append(0) + + def __getitem__(self, index): + lq_path = self.data_info['lq_path'][index] + gt_path = self.data_info['gt_path'][index] + imgs_lq = read_img_seq(lq_path) + img_gt = read_img_seq([gt_path]) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': self.data_info['folder'][index], # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/843 + 'border': self.data_info['border'][index], # 0 for non-border + 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame + } + + def __len__(self): + return len(self.data_info['gt_path']) + + +@DATASET_REGISTRY.register() +class VideoTestDUFDataset(VideoTestDataset): + """ Video test dataset for DUF dataset. + + Args: + opt (dict): Config for train dataset. + Most of keys are the same as VideoTestDataset. + It has the following extra keys: + + use_duf_downsampling (bool): Whether to use duf downsampling to + generate low-resolution frames. + scale (bool): Scale, which will be added automatically. 
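+
+        Note (editor's addition): when ``use_duf_downsampling`` is True, the
+        LQ clip is synthesized on the fly by blurring and downsampling the GT
+        frames with ``duf_downsample(..., kernel_size=13, scale=opt['scale'])``
+        instead of reading pre-generated LQ frames.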
+ """ + + def __getitem__(self, index): + folder = self.data_info['folder'][index] + idx, max_idx = self.data_info['idx'][index].split('/') + idx, max_idx = int(idx), int(max_idx) + border = self.data_info['border'][index] + lq_path = self.data_info['lq_path'][index] + + select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) + + if self.cache_data: + if self.opt['use_duf_downsampling']: + # read imgs_gt to generate low-resolution frames + imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx)) + imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale']) + else: + imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx)) + img_gt = self.imgs_gt[folder][idx] + else: + if self.opt['use_duf_downsampling']: + img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx] + # read imgs_gt to generate low-resolution frames + imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale']) + imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale']) + else: + img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale']) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': folder, # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/99 + 'border': border, # 1 for border, 0 for non-border + 'lq_path': lq_path # center frame + } + + +@DATASET_REGISTRY.register() +class VideoRecurrentTestDataset(VideoTestDataset): + """Video test dataset for recurrent architectures, which takes LR video + frames as input and output corresponding HR video frames. + + Args: + Same as VideoTestDataset. + Unused opt: + padding (str): Padding mode. + + """ + + def __init__(self, opt): + super(VideoRecurrentTestDataset, self).__init__(opt) + # Find unique folder strings + self.folders = sorted(list(set(self.data_info['folder']))) + + def __getitem__(self, index): + folder = self.folders[index] + + if self.cache_data: + imgs_lq = self.imgs_lq[folder] + imgs_gt = self.imgs_gt[folder] + else: + raise NotImplementedError('Without cache_data is not implemented.') + + return { + 'lq': imgs_lq, + 'gt': imgs_gt, + 'folder': folder, + } + + def __len__(self): + return len(self.folders) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/data/vimeo90k_dataset.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/vimeo90k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de816999728d193b5b8c9a8d7fc02a8f77967a78 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/data/vimeo90k_dataset.py @@ -0,0 +1,192 @@ +import random +import torch +from pathlib import Path +from torch.utils import data as data + +from r_basicsr.data.transforms import augment, paired_random_crop +from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from r_basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class Vimeo90KDataset(data.Dataset): + """Vimeo90K dataset for training. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt + + Each line contains: + 1. clip name; 2. frame number; 3. image shape, separated by a white space. 
+ Examples: + 00001/0001 7 (256,448,3) + 00001/0002 7 (256,448,3) + + Key examples: "00001/0001" + GT (gt): Ground-Truth; + LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. + + The neighboring frame list for different num_frame: + num_frame | frame list + 1 | 4 + 3 | 3,4,5 + 5 | 2,3,4,5,6 + 7 | 1,2,3,4,5,6,7 + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + + num_frame (int): Window size for input frames. + gt_size (int): Cropped patched size for gt patches. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + """ + + def __init__(self, opt): + super(Vimeo90KDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + + with open(opt['meta_info_file'], 'r') as fin: + self.keys = [line.split(' ')[0] for line in fin] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # indices of input images + self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] + + # temporal augmentation configs + self.random_reverse = opt['random_reverse'] + logger = get_root_logger() + logger.info(f'Random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # random reverse + if self.random_reverse and random.random() < 0.5: + self.neighbor_list.reverse() + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip, seq = key.split('/') # key example: 00001/0001 + + # get the GT frame (im4.png) + if self.is_lmdb: + img_gt_path = f'{key}/im4' + else: + img_gt_path = self.gt_root / clip / seq / 'im4.png' + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + # get the neighboring LQ frames + img_lqs = [] + for neighbor in self.neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip}/{seq}/im{neighbor}' + else: + img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # randomly crop + img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.append(img_gt) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[0:-1], dim=0) + img_gt = img_results[-1] + + # img_lqs: (t, c, h, w) + # img_gt: (c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gt, 'key': key} + + def __len__(self): + return len(self.keys) + + +@DATASET_REGISTRY.register() +class Vimeo90KRecurrentDataset(Vimeo90KDataset): + + def __init__(self, opt): + super(Vimeo90KRecurrentDataset, self).__init__(opt) + + self.flip_sequence = 
opt['flip_sequence'] + self.neighbor_list = [1, 2, 3, 4, 5, 6, 7] + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # random reverse + if self.random_reverse and random.random() < 0.5: + self.neighbor_list.reverse() + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip, seq = key.split('/') # key example: 00001/0001 + + # get the neighboring LQ and GT frames + img_lqs = [] + img_gts = [] + for neighbor in self.neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip}/{seq}/im{neighbor}' + img_gt_path = f'{clip}/{seq}/im{neighbor}' + else: + img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' + img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png' + # LQ + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + # GT + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + img_lqs.append(img_lq) + img_gts.append(img_gt) + + # randomly crop + img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.extend(img_gts) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[:7], dim=0) + img_gts = torch.stack(img_results[7:], dim=0) + + if self.flip_sequence: # flip the sequence: 7 frames to 14 frames + img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0) + img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0) + + # img_lqs: (t, c, h, w) + # img_gt: (c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gts, 'key': key} + + def __len__(self): + return len(self.keys) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0350223ae30c092262bde2ed4ed817aeed04ed30 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/__init__.py @@ -0,0 +1,31 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import LOSS_REGISTRY +from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty + +__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] + +# automatically scan and import loss modules for registry +# scan all the files under the 'losses' folder and collect files ending with '_loss.py' +loss_folder = osp.dirname(osp.abspath(__file__)) +loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] +# import all the loss modules +_model_modules = [importlib.import_module(f'r_basicsr.losses.{file_name}') for file_name in loss_filenames] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. 
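+
+    Example (editor's illustrative sketch, not original BasicSR code; it
+    assumes ``L1Loss`` is registered, as done in ``basic_loss.py``)::
+
+        cri_pix = build_loss(dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
+        loss = cri_pix(pred, target)   # pred/target: (N, C, H, W) tensors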
+ """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/basic_loss.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/basic_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0156fb7afd6f2ddeb2d3dea6c96eef04183da8ea --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/basic_loss.py @@ -0,0 +1,253 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.archs.vgg_arch import VGGFeatureExtractor +from r_basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. 
Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. 
+ + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/gan_loss.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/gan_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..88b4c196f0bc44cb3fb4822256fdf320a8f85203 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/gan_loss.py @@ -0,0 +1,208 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. 
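A hedged usage sketch of the `GANLoss` module defined above; the discriminator output here is a random stand-in, and the chosen `gan_type` and `loss_weight` are illustrative:

import torch
from r_basicsr.losses.gan_loss import GANLoss

cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)
fake_pred = torch.randn(4, 1)  # stand-in for a discriminator's output on generated images

l_g_gan = cri_gan(fake_pred, True, is_disc=False)   # generator term, scaled by loss_weight
l_d_fake = cri_gan(fake_pred, False, is_disc=True)  # discriminator term, loss_weight is ignored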
+ + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. 
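A minimal sketch of calling `r1_penalty` (its body follows below) with a toy discriminator; the network and image sizes are illustrative, and the real batch must carry `requires_grad=True` so the gradient with respect to the input exists:

import torch
from r_basicsr.losses.gan_loss import r1_penalty

disc = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 16 * 16, 1))  # toy discriminator
real_img = torch.randn(2, 3, 16, 16, requires_grad=True)
penalty = r1_penalty(disc(real_img), real_img)  # mean squared gradient norm on real data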
+ """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/loss_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fd293ff9e6a22814e5aeff6ae11fb54d2e4bafff --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/losses/loss_util.py @@ -0,0 +1,145 @@ +import functools +import torch +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. 
+ """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper + + +def get_local_weights(residual, ksize): + """Get local weights for generating the artifact map of LDL. + + It is only called by the `get_refined_artifact_map` function. + + Args: + residual (Tensor): Residual between predicted and ground truth images. + ksize (Int): size of the local window. + + Returns: + Tensor: weight for each pixel to be discriminated as an artifact pixel + """ + + pad = (ksize - 1) // 2 + residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') + + unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) + pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) + + return pixel_level_weight + + +def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): + """Calculate the artifact map of LDL + (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) + + Args: + img_gt (Tensor): ground truth images. + img_output (Tensor): output images given by the optimizing model. + img_ema (Tensor): output images given by the ema model. + ksize (Int): size of the local window. + + Returns: + overall_weight: weight for each pixel to be discriminated as an artifact pixel + (calculated based on both local and global observations). 
+ """ + + residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) + residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) + + patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) + pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) + overall_weight = patch_level_weight * pixel_level_weight + + overall_weight[residual_sr < residual_ema] = 0 + + return overall_weight diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f397f9de59d2ddc56d431e5a8f60dc16389d9105 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from r_basicsr.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/fid.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..903ddf5810fe7dfa9959763ccfdaf1d01f12c6da --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/fid.py @@ -0,0 +1,93 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from r_basicsr.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. + """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. 
+ mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/metric_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f39fee4701a6eedc615cc450ac93fd7c57157e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from r_basicsr.utils import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a004f85b6a8028c3c5504a62b1651664029c1b --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe.py @@ -0,0 +1,197 @@ +import cv2 +import math +import numpy as np +import os +from scipy.ndimage.filters import convolve +from scipy.special import gamma + +from r_basicsr.metrics.metric_util import reorder_image, to_y_channel +from r_basicsr.utils.matlab_functions import imresize +from r_basicsr.utils.registry import METRIC_REGISTRY + + +def estimate_aggd_param(block): + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters. + + Args: + block (ndarray): 2D Image block. 
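As a sanity check of the AGGD fit implemented below: for a zero-mean Gaussian block the estimated shape parameter should come out close to 2, since a Gaussian is a generalized Gaussian with shape 2 (the block size here is arbitrary):

import numpy as np
from r_basicsr.metrics.niqe import estimate_aggd_param

block = np.random.randn(96, 96)  # zero-mean, unit-variance Gaussian block
alpha, beta_l, beta_r = estimate_aggd_param(block)  # alpha ~ 2, beta_l ~ beta_r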
+ + Returns: + tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD + distribution (Estimating the parames in Equation 7 in the paper). + """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implementation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. + gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). 
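A hedged usage sketch of the registered `calculate_niqe` metric defined further below, assuming a BGR image loaded with OpenCV; the file path is only an example (it mirrors the test image mentioned in the docstring) and any natural image of reasonable size works:

import cv2
from r_basicsr.metrics.niqe import calculate_niqe

img = cv2.imread('tests/data/baboon.png')  # BGR, uint8, HWC
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
# lower NIQE indicates better (more "natural") perceptual quality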
+ """ + assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + # crop image + h, w = img.shape + num_block_h = math.floor(h / block_size_h) + num_block_w = math.floor(w / block_size_w) + img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w] + + distparam = [] # dist param is actually the multiscale features + for scale in (1, 2): # perform on two scales (1, 2) + mu = convolve(img, gaussian_window, mode='nearest') + sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu))) + # normalize, as in Eq. 1 in the paper + img_nomalized = (img - mu) / (sigma + 1) + + feat = [] + for idx_w in range(num_block_w): + for idx_h in range(num_block_h): + # process ecah block + block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, + idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale] + feat.append(compute_feature(block)) + + distparam.append(np.array(feat)) + + if scale == 1: + img = imresize(img / 255., scale=0.5, antialiasing=True) + img = img * 255. + + distparam = np.concatenate(distparam, axis=1) + + # fit a MVG (multivariate Gaussian) model to distorted patch features + mu_distparam = np.nanmean(distparam, axis=0) + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) + + # compute niqe quality, Eq. 10 in the paper + invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) + quality = np.matmul( + np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))) + + quality = np.sqrt(quality) + quality = float(np.squeeze(quality)) + return quality + + +@METRIC_REGISTRY.register() +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296) + > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296) + + We use the official params estimated from the pristine dataset. + We use the recommended block size (96, 96) without overlaps. + + Args: + img (ndarray): Input image whose quality needs to be computed. + The input image must be in range [0, 255] with float/int type. + The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order) + If the input order is 'HWC' or 'CHW', it will be converted to gray + or Y (of YCbCr) image according to the ``convert_to`` argument. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the metric calculation. + input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. + Default: 'HWC'. + convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'. + Default: 'y'. + + Returns: + float: NIQE result. + """ + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + # we use the official params estimated from the pristine dataset. 
+ niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + # round is necessary for being consistent with MATLAB's result + img = img.round() + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe_pris_params.npz b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe_pris_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..204ddcee87c4cd39aca04a42b539f0a5bfccecc3 Binary files /dev/null and b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/niqe_pris_params.npz differ diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/psnr_ssim.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..79636d7fc983262b23e91c4cd65230c39dcbb5a7 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/metrics/psnr_ssim.py @@ -0,0 +1,233 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from r_basicsr.metrics.metric_util import reorder_image, to_y_channel +from r_basicsr.utils.color_util import rgb2ycbcr_pt +from r_basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). 
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. 
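For reference, a small sketch exercising the NumPy variants exported from `r_basicsr.metrics` above; the images are random stand-ins in HWC order with range [0, 255], and the crop border is illustrative:

import numpy as np
from r_basicsr.metrics import calculate_psnr, calculate_ssim

img = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
img2 = np.clip(img + np.random.randn(64, 64, 3) * 5, 0, 255)

psnr = calculate_psnr(img, img2, crop_border=4, test_y_channel=True)  # in dB, higher is better
ssim = calculate_ssim(img, img2, crop_border=4, test_y_channel=True)  # higher is better, max 1.0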
+ """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. 
+ """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3096755c58766883a42406a1186e9c16f7de712 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/__init__.py @@ -0,0 +1,29 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'r_basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/base_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8c79c820dee5636dc4336dead2d9936aa698d881 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/base_model.py @@ -0,0 +1,380 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from r_basicsr.models import lr_scheduler as lr_scheduler +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. 
+ + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. 
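For clarity, a standalone sketch of the exponential moving average update that `model_ema` above applies parameter-by-parameter to `net_g_ema`; the tiny linear layer is just a stand-in for the generator:

import copy
import torch

net = torch.nn.Linear(4, 4)
net_ema = copy.deepcopy(net)  # the EMA copy tracked alongside the trained network

decay = 0.999
with torch.no_grad():
    for p_ema, p in zip(net_ema.parameters(), net.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)  # ema <- decay * ema + (1 - decay) * current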
+ """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warm-up. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warm-up iter numbers. -1 for no warm-up. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. 
Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. 
Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/edvr_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/edvr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..af2e302947083871922c145fe38b46580673c318 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/edvr_model.py @@ -0,0 +1,62 @@ +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class EDVRModel(VideoBaseModel): + """EDVR Model. + + Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. 
# noqa: E501 + """ + + def __init__(self, opt): + super(EDVRModel, self).__init__(opt) + if self.is_train: + self.train_tsa_iter = opt['train'].get('tsa_iter') + + def setup_optimizers(self): + train_opt = self.opt['train'] + dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') + if dcn_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate dcn params and normal params for different lr + normal_params = [] + dcn_params = [] + for name, param in self.net_g.named_parameters(): + if 'dcn' in name: + dcn_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': dcn_params, + 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.train_tsa_iter: + if current_iter == 1: + logger = get_root_logger() + logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'fusion' not in name: + param.requires_grad = False + elif current_iter == self.train_tsa_iter: + logger = get_root_logger() + logger.warning('Train all the parameters.') + for param in self.net_g.parameters(): + param.requires_grad = True + + super(EDVRModel, self).optimize_parameters(current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/esrgan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/esrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae51e979e4588f245f241a233060ac9a2a9731b0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/esrgan_model.py @@ -0,0 +1,83 @@ +import torch +from collections import OrderedDict + +from r_basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel + + +@MODEL_REGISTRY.register() +class ESRGANModel(SRGANModel): + """ESRGAN model for single image super-resolution.""" + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss (relativistic gan) + real_d_pred = self.net_d(self.gt).detach() + fake_g_pred = self.net_d(self.output) + l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) + l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # gan 
loss (relativistic gan) + + # In order to avoid the error in distributed training: + # "Error detected in CudnnBatchNormBackward: RuntimeError: one of + # the variables needed for gradient computation has been modified by + # an inplace operation", + # we separate the backwards for real and fake, and also detach the + # tensor for calculating mean. + + # real + fake_d_pred = self.net_d(self.output).detach() + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 + l_d_fake.backward() + self.optimizer_d.step() + + loss_dict['l_d_real'] = l_d_real + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/hifacegan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/hifacegan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f22a13886db2bf53fc62380d63c4fd05ac4fb246 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/hifacegan_model.py @@ -0,0 +1,288 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class HiFaceGANModel(SRModel): + """HiFaceGAN model for generic-purpose face restoration. + No prior modeling required, works for any degradations. + Currently doesn't support EMA for inference. + """ + + def init_training_settings(self): + + train_opt = self.opt['train'] + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + raise (NotImplementedError('HiFaceGAN does not support EMA now. 
Pass')) + + self.net_g.train() + + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # define losses + # HiFaceGAN does not use pixel loss by default + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('feature_matching_opt'): + self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device) + else: + self.cri_feat = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def discriminate(self, input_lq, output, ground_truth): + """ + This is a conditional (on the input) discriminator + In Batch Normalization, the fake and real images are + recommended to be in the same batch to avoid disparate + statistics in fake and real images. + So both fake and real images are fed to D all at once. + """ + h, w = output.shape[-2:] + if output.shape[-2:] != input_lq.shape[-2:]: + lq = torch.nn.functional.interpolate(input_lq, (h, w)) + real = torch.nn.functional.interpolate(ground_truth, (h, w)) + fake_concat = torch.cat([lq, output], dim=1) + real_concat = torch.cat([lq, real], dim=1) + else: + fake_concat = torch.cat([input_lq, output], dim=1) + real_concat = torch.cat([input_lq, ground_truth], dim=1) + + fake_and_real = torch.cat([fake_concat, real_concat], dim=0) + discriminator_out = self.net_d(fake_and_real) + pred_fake, pred_real = self._divide_pred(discriminator_out) + return pred_fake, pred_real + + @staticmethod + def _divide_pred(pred): + """ + Take the prediction of fake and real images from the combined batch. 
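+        For a plain (single-scale) tensor this split is a simple chunk along the
+        batch dimension; a minimal doctest-style sketch, assuming a combined
+        batch of 4 fake samples followed by 4 real samples:
+
+        >>> import torch
+        >>> pred = torch.arange(8).view(8, 1)  # rows 0-3 fake, rows 4-7 real
+        >>> fake, real = pred[:pred.size(0) // 2], pred[pred.size(0) // 2:]
+        >>> fake.flatten().tolist(), real.flatten().tolist()
+        ([0, 1, 2, 3], [4, 5, 6, 7])
+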
+ The prediction contains the intermediate outputs of multiscale GAN, + so it's usually a list + """ + if type(pred) == list: + fake = [] + real = [] + for p in pred: + fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) + real.append([tensor[tensor.size(0) // 2:] for tensor in p]) + else: + fake = pred[:pred.size(0) // 2] + real = pred[pred.size(0) // 2:] + + return fake, real + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # Requires real prediction for feature matching loss + pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt) + l_g_gan = self.cri_gan(pred_fake, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # feature matching loss + if self.cri_feat: + l_g_feat = self.cri_feat(pred_fake, pred_real) + l_g_total += l_g_feat + loss_dict['l_g_feat'] = l_g_feat + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # TODO: Benchmark test between HiFaceGAN and SRGAN implementation: + # SRGAN use the same fake output for discriminator update + # while HiFaceGAN regenerate a new output using updated net_g + # This should not make too much difference though. Stick to SRGAN now. + # ------------------------------------------------------------------- + # ---------- Below are original HiFaceGAN code snippet -------------- + # ------------------------------------------------------------------- + # with torch.no_grad(): + # fake_image = self.net_g(self.lq) + # fake_image = fake_image.detach() + # fake_image.requires_grad_() + # pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt) + + # real + pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt) + l_d_real = self.cri_gan(pred_real, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + # fake + l_d_fake = self.cri_gan(pred_fake, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + + l_d_total = (l_d_real + l_d_fake) / 2 + l_d_total.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + print('HiFaceGAN does not support EMA now. pass') + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """ + Warning: HiFaceGAN requires train() mode even for validation + For more info, see https://github.com/Lotayou/Face-Renovation/issues/31 + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. 
+ """ + + if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'): + self.net_g.train() + + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + print('In HiFaceGANModel: The new metrics package is under development.' + + 'Using super method now (Only PSNR & SSIM are supported)') + super().nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + """ + TODO: Validation using updated metric system + The metrics are now evaluated after all images have been tested + This allows batch processing, and also allows evaluation of + distributional metrics, such as: + + @ Frechet Inception Distance: FID + @ Maximum Mean Discrepancy: MMD + + Warning: + Need careful batch management for different inference settings. + + """ + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()} + sr_tensors = [] + gt_tensors = [] + + pbar = tqdm(total=len(dataloader), unit='image') + for val_data in dataloader: + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze + sr_tensors.append(visuals['result']) + if 'gt' in visuals: + gt_tensors.append(visuals['gt']) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + + imwrite(tensor2img(visuals['result']), save_img_path) + + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + sr_pack = torch.cat(sr_tensors, dim=0) + gt_pack = torch.cat(gt_tensors, dim=0) + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + # The new metric caller automatically returns mean value + # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run + self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_) + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + print('HiFaceGAN does not support EMA now. Fallback to normal mode.') + + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/lr_scheduler.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. 
+ + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' 
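+        # Worked example (using the periods = [10, 10, 10, 10] configuration from
+        # the class docstring): cumulative_period becomes [10, 20, 30, 40], so an
+        # iteration of 25 falls into cycle index 2, restarts from
+        # nearest_restart = 20, and the cosine term is evaluated at
+        # (25 - 20) / 10 of that cycle with weight restart_weights[2].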
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrgan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..54cde8f11bc489788d61acd61c5aaf71c51088e3 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrgan_model.py @@ -0,0 +1,267 @@ +import numpy as np +import random +import torch +from collections import OrderedDict +from torch.nn import functional as F + +from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from r_basicsr.data.transforms import paired_random_crop +from r_basicsr.losses.loss_util import get_refined_artifact_map +from r_basicsr.models.srgan_model import SRGANModel +from r_basicsr.utils import DiffJPEG, USMSharp +from r_basicsr.utils.img_process_util import filter2D +from r_basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. 
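+
+        A minimal standalone sketch of the pool mechanics (sizes and names here
+        are illustrative, not the buffers used below):
+
+        >>> import torch
+        >>> pool, ptr, b = torch.zeros(8, 3, 4, 4), 0, 2   # pool of 8, batch of 2
+        >>> batch = torch.rand(b, 3, 4, 4)
+        >>> pool[ptr:ptr + b] = batch; ptr += b            # enqueue until the pool is full
+        >>> # once ptr reaches 8: shuffle the pool, return its first b samples as the
+        >>> # training batch, and write the incoming batch into those slots instead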
+ """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + 
if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + if self.cri_ldl: + self.output_ema = self.net_g_ema(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + if self.cri_ldl: + pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7) + l_g_ldl = 
self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt)) + l_g_total += l_g_ldl + loss_dict['l_g_ldl'] = l_g_ldl + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrnet_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4108dd480588d7006e8e5400b58dccb6d22e632b --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/realesrnet_model.py @@ -0,0 +1,189 @@ +import numpy as np +import random +import torch +from torch.nn import functional as F + +from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from r_basicsr.data.transforms import paired_random_crop +from r_basicsr.models.sr_model import SRModel +from r_basicsr.utils import DiffJPEG, USMSharp +from r_basicsr.utils.img_process_util import filter2D +from r_basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRNetModel(SRModel): + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. 
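+
+        A small worked check of the divisibility requirement enforced below
+        (values are illustrative): with the default queue_size of 180 and a batch
+        size of 12, 180 % 12 == 0, so the pool holds exactly 15 batches.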
+ """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add 
noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/sr_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..56322850990e7e2893a40b4e70e2ee6554ca9db3 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/sr_model.py @@ -0,0 +1,231 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + 
self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = 
dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/srgan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/srgan_model.py new file mode 100644 
index 0000000000000000000000000000000000000000..593bbb5f912989ad6d934a0d9b6ac4550b6b4c58 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/srgan_model.py @@ -0,0 +1,149 @@ +import torch +from collections import OrderedDict + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SRGANModel(SRModel): + """SRGAN model for single image super-resolution.""" + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('ldl_opt'): + self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) + else: + self.cri_ldl = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # 
perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/stylegan2_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/stylegan2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..24037cd0ff8708b1792662ede24af65fba5119db --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/stylegan2_model.py @@ -0,0 +1,283 @@ +import cv2 +import math +import numpy as np +import random +import torch +from collections import OrderedDict +from os import path as osp + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.losses.gan_loss import g_path_regularize, r1_penalty +from r_basicsr.utils import imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class StyleGAN2Model(BaseModel): + """StyleGAN2 model.""" + + def __init__(self, opt): + super(StyleGAN2Model, self).__init__(opt) + + # define network net_g + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + # latent dimension: self.num_style_feat + self.num_style_feat = opt['network_g']['num_style_feat'] + num_val_samples = self.opt['val'].get('num_val_samples', 16) + self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained model + load_path = 
self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving, do not need to + # wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # define losses + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.path_reg_weight = train_opt['path_reg_weight'] # for generator + + self.net_g_reg_every = train_opt['net_g_reg_every'] + self.net_d_reg_every = train_opt['net_d_reg_every'] + self.mixing_prob = train_opt['mixing_prob'] + + self.mean_path_length = 0 + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1) + if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC': + normal_params = [] + style_mlp_params = [] + modulation_conv_params = [] + for name, param in self.net_g.named_parameters(): + if 'modulation' in name: + normal_params.append(param) + elif 'style_mlp' in name: + style_mlp_params.append(param) + elif 'modulated_conv' in name: + modulation_conv_params.append(param) + else: + normal_params.append(param) + optim_params_g = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': style_mlp_params, + 'lr': train_opt['optim_g']['lr'] * 0.01 + }, + { + 'params': modulation_conv_params, + 'lr': train_opt['optim_g']['lr'] / 3 + } + ] + else: + normal_params = [] + for name, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # optimizer d + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC': + normal_params = [] + linear_params = [] + for name, param in self.net_d.named_parameters(): + if 'final_linear' in name: + linear_params.append(param) + else: + normal_params.append(param) + optim_params_d = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }, + { + 'params': linear_params, + 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512)) + } + ] + else: + normal_params = [] + for name, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + + optim_type = 
train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + def feed_data(self, data): + self.real_img = data['gt'].to(self.device) + + def make_noise(self, batch, num_noise): + if num_noise == 1: + noises = torch.randn(batch, self.num_style_feat, device=self.device) + else: + noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0) + return noises + + def mixing_noise(self, batch, prob): + if random.random() < prob: + return self.make_noise(batch, 2) + else: + return [self.make_noise(batch, 1)] + + def optimize_parameters(self, current_iter): + loss_dict = OrderedDict() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + + batch = self.real_img.size(0) + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img.detach()) + + real_pred = self.net_d(self.real_img) + # wgan loss with softplus (logistic loss) for discriminator + l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In wgan, real_score should be positive and fake_score should be + # negative + loss_dict['real_score'] = real_pred.detach().mean() + loss_dict['fake_score'] = fake_pred.detach().mean() + l_d.backward() + + if current_iter % self.net_d_reg_every == 0: + self.real_img.requires_grad = True + real_pred = self.net_d(self.real_img) + l_d_r1 = r1_penalty(real_pred, self.real_img) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + # TODO: why do we need to add 0 * real_pred, otherwise, a runtime + # error will arise: RuntimeError: Expected to have finished + # reduction in the prior iteration before starting a new one. + # This error indicates that your module has parameters that were + # not used in producing loss. 
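+            # For reference, the R1 term computed here is the gradient penalty on
+            # real images only, E_x[ ||d D(x) / d x||_2^2 ]; with lazy
+            # regularization it is applied only every net_d_reg_every iterations,
+            # which is why the weight below is scaled by self.net_d_reg_every.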
+ loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img) + + # wgan loss with softplus (non-saturating loss) for generator + l_g = self.cri_gan(fake_pred, True, is_disc=False) + loss_dict['l_g'] = l_g + l_g.backward() + + if current_iter % self.net_g_reg_every == 0: + path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink']) + noise = self.mixing_noise(path_batch_size, self.mixing_prob) + fake_img, latents = self.net_g(noise, return_latents=True) + l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length) + + l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0]) + # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0] + l_g_path.backward() + loss_dict['l_g_path'] = l_g_path.detach().mean() + loss_dict['path_length'] = path_lengths + + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + def test(self): + with torch.no_grad(): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema([self.fixed_sample]) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + assert dataloader is None, 'Validation dataloader should be None.' + self.test() + result = tensor2img(self.output, min_max=(-1, 1)) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png') + imwrite(result, save_img_path) + # add sample images to tb_logger + result = (result / 255.).astype(np.float32) + result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) + if tb_logger is not None: + tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC') + + def save(self, epoch, current_iter): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/swinir_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/swinir_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7241324dc742379a0c246eec455d7a165bd8313d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/swinir_model.py @@ -0,0 +1,33 @@ +import torch +from torch.nn import functional as F + +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SwinIRModel(SRModel): + + def test(self): + # pad to multiplication of window_size + window_size = self.opt['network_g']['window_size'] + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = self.lq.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + if hasattr(self, 'net_g_ema'): + 
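+            # prefer the EMA weights for inference when they are available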
self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(img) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(img) + self.net_g.train() + + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_base_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..51a2eb89e1caabc4e007914414abb141852dd482 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_base_model.py @@ -0,0 +1,160 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.dist_util import get_dist_info +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class VideoBaseModel(SRModel): + """Base video SR model.""" + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + # record all frames (border and center frames) + if rank == 0: + pbar = tqdm(total=len(dataset), unit='frame') + for idx in range(rank, len(dataset), world_size): + val_data = dataset[idx] + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + folder = val_data['folder'] + frame_idx, max_idx = val_data['idx'].split('/') + lq_path = val_data['lq_path'] + + self.feed_data(val_data) + self.test() + visuals = self.get_current_visuals() + result_img = tensor2img([visuals['result']]) + metric_data['img'] = result_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if 'vimeo' in dataset_name.lower(): # vimeo90k dataset + split_result = lq_path.split('/') + img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}' + else: # other datasets, e.g., REDS, Vid4 + img_name = osp.splitext(osp.basename(lq_path))[0] + + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + 
f'{img_name}_{self.opt["name"]}.png') + imwrite(result_img, save_img_path) + + if with_metrics: + # calculate metrics + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = calculate_metric(metric_data, opt_) + self.metric_results[folder][int(frame_idx), metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}') + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + else: + pass # assume use one gpu in non-dist testing + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + logger = get_root_logger() + logger.warning('nondist_validation is not implemented. Run dist_validation.') + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + # ----------------- calculate the average values for each folder, and for each metric ----------------- # + # average all frames for each sub-folder + # metric_results_avg is a dict:{ + # 'folder1': tensor (len(metrics)), + # 'folder2': tensor (len(metrics)) + # } + metric_results_avg = { + folder: torch.mean(tensor, dim=0).cpu() + for (folder, tensor) in self.metric_results.items() + } + # total_avg_results is a dict: { + # 'metric1': float, + # 'metric2': float + # } + total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + for folder, tensor in metric_results_avg.items(): + for idx, metric in enumerate(total_avg_results.keys()): + total_avg_results[metric] += metric_results_avg[folder][idx].item() + # average among folders + for metric in total_avg_results.keys(): + total_avg_results[metric] /= len(metric_results_avg) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter) + + # ------------------------------------------ log the metric ------------------------------------------ # + log_str = f'Validation {dataset_name}\n' + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + log_str += f'\t # {metric}: {value:.4f}' + for folder, tensor in metric_results_avg.items(): + log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + for folder, tensor in metric_results_avg.items(): + tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_gan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c98b11de249b4e62d1a177c9efcf4d420a94fa2e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_gan_model.py @@ -0,0 +1,17 @@ +from r_basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import 
SRGANModel +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoGANModel(SRGANModel, VideoBaseModel): + """Video GAN model. + + Use multiple inheritance. + It will first use the functions of SRGANModel: + init_training_settings + setup_optimizers + optimize_parameters + save + Then find functions in VideoBaseModel. + """ diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_gan_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0c737baf75e4da84f0cda070bc8b1e03a88de76e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_gan_model.py @@ -0,0 +1,180 @@ +import torch +from collections import OrderedDict + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_recurrent_model import VideoRecurrentModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentGANModel(VideoRecurrentModel): + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # build network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving. + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + if train_opt['fix_flow']: + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: # The fix_flow now only works for spynet. 
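+                    # The flow estimator ('spynet') provides the optical flow used for
+                    # feature alignment; its parameters are collected into a separate
+                    # group so they can be trained with their own learning rate
+                    # `lr_flow`, while all remaining parameters use `optim_g.lr`
+                    # (see the two parameter groups built below).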
+ flow_params.append(param) + else: + normal_params.append(param) + + optim_params = [ + { # add flow params first + 'params': flow_params, + 'lr': train_opt['lr_flow'] + }, + { + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + ] + else: + optim_params = self.net_g.parameters() + + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + if self.fix_flow_iter: + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + _, _, c, h, w = self.output.size() + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w)) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output.view(-1, c, h, w)) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + # reshape to (b*n, c, h, w) + real_d_pred = self.net_d(self.gt.view(-1, c, h, w)) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + # reshape to (b*n, c, h, w) + fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_model.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..10bc98defe3e39c4f9f0725616cf208f01ab597a --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/models/video_recurrent_model.py @@ -0,0 +1,197 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.dist_util import get_dist_info +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentModel(VideoBaseModel): + + def __init__(self, opt): + super(VideoRecurrentModel, self).__init__(opt) + if self.is_train: + self.fix_flow_iter = opt['train'].get('fix_flow') + + def setup_optimizers(self): + train_opt = self.opt['train'] + flow_lr_mul = train_opt.get('flow_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.') + if flow_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate flow params and normal params for different lr + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: + flow_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': flow_params, + 'lr': train_opt['optim_g']['lr'] * flow_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.fix_flow_iter: + logger = get_root_logger() + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + super(VideoRecurrentModel, self).optimize_parameters(current_iter) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + num_folders = len(dataset) + num_pad = (world_size - (num_folders % world_size)) % world_size + if rank == 0: + pbar = tqdm(total=len(dataset), unit='folder') + # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded. 
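+        # Each rank processes folders rank, rank + world_size, ...; when num_folders
+        # is not divisible by world_size, some ranks would otherwise run fewer
+        # iterations and the collective calls below (dist.reduce / dist.barrier)
+        # would hang. The padded iterations simply re-evaluate the last folder
+        # (idx = min(i, num_folders - 1)) and their results are skipped by the
+        # `if i < num_folders` check.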
+ # (To avoid wait-dead) + for i in range(rank, num_folders + num_pad, world_size): + idx = min(i, num_folders - 1) + val_data = dataset[idx] + folder = val_data['folder'] + + # compute outputs + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + self.feed_data(val_data) + val_data['lq'].squeeze_(0) + val_data['gt'].squeeze_(0) + + self.test() + visuals = self.get_current_visuals() + + # tentative for out of GPU memory + del self.lq + del self.output + if 'gt' in visuals: + del self.gt + torch.cuda.empty_cache() + + if self.center_frame_only: + visuals['result'] = visuals['result'].unsqueeze(1) + if 'gt' in visuals: + visuals['gt'] = visuals['gt'].unsqueeze(1) + + # evaluate + if i < num_folders: + for idx in range(visuals['result'].size(1)): + result = visuals['result'][0, idx, :, :, :] + result_img = tensor2img([result]) # uint8, bgr + metric_data['img'] = result_img + if 'gt' in visuals: + gt = visuals['gt'][0, idx, :, :, :] + gt_img = tensor2img([gt]) # uint8, bgr + metric_data['img2'] = gt_img + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if self.center_frame_only: # vimeo-90k + clip_ = val_data['lq_path'].split('/')[-3] + seq_ = val_data['lq_path'].split('/')[-2] + name_ = f'{clip_}_{seq_}' + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{name_}_{self.opt['name']}.png") + else: # others + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{idx:08d}_{self.opt['name']}.png") + # image name only for REDS dataset + imwrite(result_img, img_path) + + # calculate metrics + if with_metrics: + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = calculate_metric(metric_data, opt_) + self.metric_results[folder][idx, metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Folder: {folder}') + + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def test(self): + n = self.lq.size(1) + self.net_g.eval() + + flip_seq = self.opt['val'].get('flip_seq', False) + self.center_frame_only = self.opt['val'].get('center_frame_only', False) + + if flip_seq: + self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1) + + with torch.no_grad(): + self.output = self.net_g(self.lq) + + if flip_seq: + output_1 = self.output[:, :n, :, :, :] + output_2 = self.output[:, n:, :, :, :].flip(1) + self.output = 0.5 * (output_1 + output_2) + + if self.center_frame_only: + self.output = self.output[:, n // 2, :, :, :] + + self.net_g.train() diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + 
modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/deform_conv.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..6268ca825d59ef4a30d4d2156c4438cbbe9b3c1e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,379 @@ +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) +else: + try: + from . import deform_conv_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if 
ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class 
DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
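+        deformable_groups (int): Number of deformable offset groups. The
+            ``conv_offset`` layer defined below predicts
+            ``deformable_groups * 2 * kernel_size[0] * kernel_size[1]`` channels,
+            i.e. an (x, y) offset for every kernel sampling location.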
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
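+        deformable_groups (int): Number of deformable offset groups. The
+            ``conv_offset`` layer defined below predicts
+            ``deformable_groups * 3 * kernel_size[0] * kernel_size[1]`` channels,
+            which ``forward`` splits into (x, y) offsets and a modulation mask
+            that is passed through a sigmoid.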
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b465c493a3dd67d320b7a8997fbd501d2f89c807 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int 
dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h 
* (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w 
+ 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int 
data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = 
w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = 
(index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, 
width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += 
(argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + 
const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 
batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = 
data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != 
cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_ext.cpp b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, 
int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git 
a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/fused_act.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..88edc445484b71119dc22a258e83aef49ce39b07 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,95 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import os +import torch +from torch import nn +from torch.autograd import Function + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) +else: + try: + from . import fused_act_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act.cpp 
b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include <torch/extension.h> + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include <torch/types.h> + +#include <ATen/ATen.h> +#include <ATen/AccumulateType.h> +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> + +#include <cuda.h> +#include <cuda_runtime.h> + + +template <typename scalar_t> +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ?
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
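+    // mode 3: 2x upsampling on both axes with a kernel of at most 4x4 taps.
+    // tile_out_h/tile_out_w choose the shared-memory output tile for the
+    // templated kernel; combinations not matched by modes 1-6 fall back to
+    // upfirdn2d_kernel_large (mode -1) in the switch below.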
tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 4;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 5;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 6;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  dim3 block_size;
+  dim3 grid_size;
+
+  if (tile_out_h > 0 && tile_out_w > 0) {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 1;
+    block_size = dim3(32 * 8, 1, 1);
+    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  } else {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 4;
+    block_size = dim3(4, 32, 1);
+    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+    switch (mode) {
+    case 1:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 16, 64, 4, 4>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 2:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 16, 64, 3, 3>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 3:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 16, 64, 4, 4>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 4:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 16, 64, 2, 2>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 5:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 8, 32, 4, 4>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 6:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 8, 32, 2, 2>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    default:
+      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+          k.data_ptr<scalar_t>(), p);
+    }
+  });
+
+  return out;
+}
diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/upfirdn2d.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6122d59aa32fd52e956bd36200ba79af4a17b17
--- /dev/null
+++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,192 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501
+
+import os
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+    from torch.utils.cpp_extension import load
+    module_path = os.path.dirname(__file__)
+    upfirdn2d_ext = load(
+        'upfirdn2d',
+        sources=[
+            os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+            os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+        ],
+    )
+else:
+    try:
+        from . import upfirdn2d_ext
+    except ImportError:
+        pass
+        # avoid annoying print output
+        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+        #       '1. compile with BASICSR_EXT=True. or\n '
+        #       '2. 
set BASICSR_JIT=True during running') + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + _, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = 
input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/test.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/test.py new file mode 100644 index 0000000000000000000000000000000000000000..eb402eca9ffcfa9c2e69976eb5c4fbd1743ea194 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/test.py @@ -0,0 +1,45 @@ +import logging +import torch +from os import path as osp + +from r_basicsr.data import build_dataloader, build_dataset +from r_basicsr.models import build_model +from r_basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from r_basicsr.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + # create model + model = build_model(opt) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/train.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f53d132e047b2bc4beec4a5a832b6dc407a981e2 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/train.py @@ -0,0 +1,215 @@ +import datetime +import logging +import math +import time +import torch +from os import path as osp + +from r_basicsr.data import build_dataloader, build_dataset +from r_basicsr.data.data_sampler import EnlargedSampler +from r_basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from 
r_basicsr.models import build_model +from r_basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, + init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) +from r_basicsr.utils.options import copy_opt_file, dict2str, parse_options + + +def init_tb_loggers(opt): + # initialize wandb logger before tensorboard logger to allow proper sync + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') + is not None) and ('debug' not in opt['name']): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name'])) + return tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loaders = None, [] + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase.split('_')[0] == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}') + val_loaders.append(val_loader) + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loaders, total_epochs, total_iters + + +def load_resume_state(opt): + resume_state_path = None + if opt['auto_resume']: + state_path = osp.join('experiments', opt['name'], 'training_states') + if osp.isdir(state_path): + states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) + if len(states) != 0: + states = [float(v.split('.state')[0]) for v in states] + resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') + opt['path']['resume_state'] = resume_state_path + else: + if opt['path'].get('resume_state'): + resume_state_path = opt['path']['resume_state'] + + if resume_state_path is None: + resume_state = None + else: + device_id = torch.cuda.current_device() + resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + check_resume(opt, resume_state['iter']) + return resume_state + + +def train_pipeline(root_path): + # parse options, set distributed setting, set random seed + opt, args = parse_options(root_path, 
is_train=True) + opt['root_path'] = root_path + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + resume_state = load_resume_state(opt) + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name'])) + + # copy the yml file to the experiment root + copy_opt_file(args.opt, opt['path']['experiments_root']) + + # WARNING: should not use get_root_logger in the above codes, including the called functions + # Otherwise the logger will not be properly initialized + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loaders, total_epochs, total_iters = result + + # create model + model = build_model(opt) + if resume_state: # resume training + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. 
Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_timer, iter_timer = AvgTimer(), AvgTimer() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_timer.record() + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_timer.record() + if current_iter == 1: + # reset start time in msg_logger for more accurate eta_time + # not work in resume mode + msg_logger.reset_start_time() + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): + if len(val_loaders) > 1: + logger.warning('Multiple validation datasets are *only* supported by SRModel.') + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_timer.start() + iter_timer.start() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. 
Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/__init__.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e999c8eddc6a5f9623863ce85232e58984e138 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/__init__.py @@ -0,0 +1,44 @@ +from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb +from .diffjpeg import DiffJPEG +from .file_client import FileClient +from .img_process_util import USMSharp, usm_sharp +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # color_util.py + 'bgr2ycbcr', + 'rgb2ycbcr', + 'rgb2ycbcr_pt', + 'ycbcr2bgr', + 'ycbcr2rgb', + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', + # diffjpeg + 'DiffJPEG', + # img_process_util + 'USMSharp', + 'usm_sharp' +] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/color_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/color_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4740d5c98dd0680654e20d46b81ab30dfe936d6e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/color_util.py @@ -0,0 +1,208 @@ +import numpy as np +import torch + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. 
+ + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. 
np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. 
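+    # out_img is now back in the [0, 1] float range, matching the input convention.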
+ return out_img diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/diffjpeg.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/diffjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..a481481cdddbd7bcbf25658d6375381cf7fc1a19 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/diffjpeg.py @@ -0,0 +1,515 @@ +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = 
image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. 
+ """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + 
super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered + + +if __name__ == '__main__': + import cv2 + + from r_basicsr.utils import img2tensor, tensor2img + + img_gt = cv2.imread('test.png') / 255. 
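+    # Quick self-test: compress the same image with OpenCV JPEG (quality 20)
+    # and with DiffJPEG (qualities 20 and 40), writing all results to disk.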
+ + # -------------- cv2 -------------- # + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] + _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) + img_lq = np.float32(cv2.imdecode(encimg, 1)) + cv2.imwrite('cv2_JPEG_20.png', img_lq) + + # -------------- DiffJPEG -------------- # + jpeger = DiffJPEG(differentiable=False).cuda() + img_gt = img2tensor(img_gt) + img_gt = torch.stack([img_gt, img_gt]).cuda() + quality = img_gt.new_tensor([20, 40]) + out = jpeger(img_gt, quality=quality) + + cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) + cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/dist_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/download_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6adda71320625242b0107f77d328e7afa236aee6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/download_util.py @@ -0,0 +1,99 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + + Args: + file_id (str): File id. + save_path (str): Save path. 
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/file_client.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. 
+ + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. 
+ """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/flow_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/flow_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7180b4e9b5c8f2eb36a9a0e4ff6affdaae84b8 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/flow_util.py @@ -0,0 +1,170 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. 
+ + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. 
+ dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_process_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_process_util.py new file mode 100644 index 0000000000000000000000000000000000000000..52e02f09930dbf13bcd12bbe16b76e4fce52578e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_process_util.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def usm_sharp(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5f1da0911d9b12f9c6164df6c6e14e3c1aef88 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
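+
+    Example (illustrative sketch; the random tensor stands in for a real
+    network output and 'out.png' is a placeholder path):
+
+        t = torch.rand(1, 3, 64, 64)      # 4D RGB tensor in [0, 1]
+        img = tensor2img(t)               # (64, 64, 3) uint8 array, BGR order
+        imwrite(img, 'out.png')           # helper defined later in this module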
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/lmdb_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/lmdb_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/logger.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..63d9fbe8bd7b7b8108cc4325661813c3ffda660b --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/logger.py @@ -0,0 +1,213 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. 
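+
+        Example (sketch of a typical call; msg_logger is an already constructed
+        MessageLogger instance and the loss key 'l_pix' is illustrative):
+
+            log_vars = {'epoch': 1, 'iter': 100, 'lrs': [1e-4],
+                        'time': 0.2, 'data_time': 0.05, 'l_pix': 0.013}
+            msg_logger(log_vars)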
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from r_basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/matlab_functions.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a201f79aaf030cdba710dd97c28af1b29a93ed2a --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/matlab_functions.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + 
for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/misc.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d4a1403509672e85e74ac476e028cefb6dbb62 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. 
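+
+    Example (a minimal sketch of the expected option layout; all paths and the
+    'network_g' entry below are illustrative placeholders):
+
+        opt = {'path': {'resume_state': 'experiments/demo/training_states/1000.state',
+                        'models': 'experiments/demo/models'},
+               'network_g': {'type': 'MSRResNet'}}
+        check_resume(opt, resume_iter=1000)
+        # sets opt['path']['pretrain_network_g'] to .../models/net_g_1000.pth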
+ """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file size. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/options.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..f8da8792273689bae7e6f0830b631b2d7d667414 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/options.py @@ -0,0 +1,194 @@ +import argparse +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp + +from r_basicsr.utils import set_random_seed +from r_basicsr.utils.dist_util import get_dist_info, init_dist, master_only + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + with open(args.opt, mode='r') as f: + opt = yaml.load(f, Loader=ordered_yaml()[0]) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + 
opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/plot_util.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/plot_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e5a3ffad2cc851d5755d5d62efc290320901b6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/plot_util.py @@ -0,0 +1,84 @@ +import re + + +def read_data_from_tensorboard(log_path, tag): + """Get raw data (steps and values) from tensorboard events. + + Args: + log_path (str): Path to the tensorboard log. + tag (str): tag to be read. + """ + from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + # tensorboard event + event_acc = EventAccumulator(log_path) + event_acc.Reload() + scalar_list = event_acc.Tags()['scalars'] + print('tag list: ', scalar_list) + steps = [int(s.step) for s in event_acc.Scalars(tag)] + values = [s.value for s in event_acc.Scalars(tag)] + return steps, values + + +def read_data_from_txt_2v(path, pattern, step_one=False): + """Read data from txt with 2 returned values (usually [step, value]). + + Args: + path (str): path to the txt file. + pattern (str): re (regular expression) pattern. + step_one (bool): add 1 to steps. Default: False. + """ + with open(path) as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + steps = [] + values = [] + + pattern = re.compile(pattern) + for line in lines: + match = pattern.match(line) + if match: + steps.append(int(match.group(1))) + values.append(float(match.group(2))) + if step_one: + steps = [v + 1 for v in steps] + return steps, values + + +def read_data_from_txt_1v(path, pattern): + """Read data from txt with 1 returned values. + + Args: + path (str): path to the txt file. + pattern (str): re (regular expression) pattern. + """ + with open(path) as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + data = [] + + pattern = re.compile(pattern) + for line in lines: + match = pattern.match(line) + if match: + data.append(float(match.group(1))) + return data + + +def smooth_data(values, smooth_weight): + """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). + + Ref: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/\ + tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 + + Args: + values (list): A list of values to be smoothed. + smooth_weight (float): Smooth weight. 
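+
+    Example (illustrative):
+
+        smooth_data([0.0, 1.0, 2.0], smooth_weight=0.6)
+        # -> approximately [0.0, 0.4, 1.04]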
+ """ + values_sm = [] + last_sm_value = values[0] + for value in values: + value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value + values_sm.append(value_sm) + last_sm_value = value_sm + return values_sm diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/registry.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5e72ef7ff21b94f50e6caa8948f69ca0b04bc968 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/custom_nodes/ComfyUI-ReActor/r_basicsr/version.py b/custom_nodes/ComfyUI-ReActor/r_basicsr/version.py new file mode 100644 index 0000000000000000000000000000000000000000..822285f560ebc3ea230bff38935b4290a6c1acf9 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Wed Apr 5 00:20:48 2023 +__version__ = '1.4.2' +__gitsha__ = 'unknown' +version_info = (1, 4, 2) diff --git a/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/gfpganv1_clean_arch.py b/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/gfpganv1_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..05fcb7237845a7a2ea496cab78c2546d9d32487d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/gfpganv1_clean_arch.py @@ -0,0 +1,370 @@ +# pylint: skip-file +# type: ignore 
+import math +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from r_chainner.archs.face.stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + narrow=1, + sft_half=False, + ): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + ) + self.sft_half = sft_half + + def forward( + self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorCSFT. + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
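+
+        Example (rough sketch; in GFPGAN the conditions come from the U-Net
+        encoder -- passing an empty list here simply skips the SFT branch, so
+        the call degenerates to a plain StyleGAN2 forward pass):
+
+            gen = StyleGAN2GeneratorCSFT(out_size=256, sft_half=True)
+            styles = [torch.randn(1, 512)]
+            image, _ = gen(styles, conditions=[], randomize_noise=True)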
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + + def __init__(self, in_channels, out_channels, mode="down"): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == "down": + self.scale_factor = 0.5 + elif mode == "up": + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate( + out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate( + x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + skip = self.skip(x) + out = out + skip + return out + + +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. 
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + num_mlp (int): Layer number of MLP style layers. Default: 8. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + state_dict, + ): + super(GFPGANv1Clean, self).__init__() + + out_size = 512 + num_style_feat = 512 + channel_multiplier = 2 + decoder_load_path = None + fix_decoder = False + num_mlp = 8 + input_is_latent = True + different_w = True + narrow = 1 + sft_half = True + + self.model_arch = "GFPGAN" + self.sub_type = "Face SR" + self.scale = 8 + self.in_nc = 3 + self.out_nc = 3 + self.state = state_dict + + self.supports_fp16 = False + self.supports_bf16 = True + self.min_size_restriction = 512 + + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + "4": int(512 * unet_narrow), + "8": int(512 * unet_narrow), + "16": int(512 * unet_narrow), + "32": int(512 * unet_narrow), + "64": int(256 * channel_multiplier * unet_narrow), + "128": int(128 * channel_multiplier * unet_narrow), + "256": int(64 * channel_multiplier * unet_narrow), + "512": int(32 * channel_multiplier * unet_narrow), + "1024": int(16 * channel_multiplier * unet_narrow), + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2 ** (int(math.log(out_size, 2))) + + self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1) + + # downsample + in_channels = channels[f"{first_out_size}"] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f"{2**(i - 1)}"] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down")) + in_channels = out_channels + + self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1) + + # upsample + in_channels = channels["4"] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up")) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorCSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + sft_half=sft_half, + ) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load( + 
decoder_load_path, map_location=lambda storage, loc: storage + )["params_ema"] + ) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.condition_shift.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.load_state_dict(state_dict) + + def forward( + self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs + ): + """Forward function for GFPGANv1Clean. + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder( + [style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise, + ) + + return image, out_rgbs diff --git a/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/stylegan2_clean_arch.py b/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/stylegan2_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..c48de9af6904b8d1891a84efa8e4d76104d5d710 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_chainner/archs/face/stylegan2_clean_arch.py @@ -0,0 +1,453 @@ +# pylint: skip-file +# type: ignore +import math + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. 
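+
+    Example (illustrative usage sketch; the module and the scale below are
+    assumed for the example, not taken from the surrounding code):
+        >>> import torch.nn as nn
+        >>> conv = nn.Conv2d(64, 64, 3, 1, 1)
+        >>> # Kaiming-init the weight, damp it by 0.1 (typical for residual
+        >>> # branches) and zero the bias.
+        >>> default_init_weights(conv, scale=0.1, bias_fill=0, a=0,
+        ...                      mode="fan_in", nonlinearity="leaky_relu")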
+ """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +class NormStyleCode(nn.Module): + def forward(self, x): + """Normalize the style codes. + Args: + x (Tensor): Style codes with shape (b, c). + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + There is no bias in ModulatedConv2d. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + ): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights( + self.modulation, + scale=1, + bias_fill=1, + a=0, + mode="fan_in", + nonlinearity="linear", + ) + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) + / math.sqrt(in_channels * kernel_size**2) + ) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + Returns: + Tensor: Modulated tensor after convolution. 
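+
+        Example (shape sketch; the sizes are assumed for illustration, not
+        taken from the surrounding code):
+            >>> import torch
+            >>> conv = ModulatedConv2d(512, 512, 3, num_style_feat=512)
+            >>> x = torch.randn(2, 512, 8, 8)
+            >>> style = torch.randn(2, 512)
+            >>> conv(x, style).shape  # stride 1, padding k // 2 keeps 8x8
+            torch.Size([2, 512, 8, 8])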
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view( + b * self.out_channels, c, self.kernel_size, self.kernel_size + ) + + # upsample or downsample if necessary + if self.sample_mode == "upsample": + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) + elif self.sample_mode == "downsample": + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})" + ) + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + ): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + ) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + ) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + Returns: + Tensor: RGB images. 
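+
+        Example (illustrative shapes only; sizes are assumed, not taken from
+        the surrounding code):
+            >>> import torch
+            >>> to_rgb = ToRGB(in_channels=256, num_style_feat=512)
+            >>> feat = torch.randn(1, 256, 16, 16)
+            >>> style = torch.randn(1, 512)
+            >>> prev_rgb = torch.randn(1, 3, 8, 8)  # skip from the previous level
+            >>> to_rgb(feat, style, prev_rgb).shape  # skip upsampled 2x, then added
+            torch.Size([1, 3, 16, 16])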
+ """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode="bilinear", align_corners=False + ) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__( + self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1 + ): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [ + nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ] + ) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights( + self.style_mlp, + scale=1, + bias_fill=0, + a=0.2, + mode="fan_in", + nonlinearity="leaky_relu", + ) + + # channel list + channels = { + "4": int(512 * narrow), + "8": int(512 * narrow), + "16": int(512 * narrow), + "32": int(512 * narrow), + "64": int(256 * channel_multiplier * narrow), + "128": int(128 * channel_multiplier * narrow), + "256": int(64 * channel_multiplier * narrow), + "512": int(32 * channel_multiplier * narrow), + "1024": int(16 * channel_multiplier * narrow), + } + self.channels = channels + + self.constant_input = ConstantInput(channels["4"], size=4) + self.style_conv1 = StyleConv( + channels["4"], + channels["4"], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels["4"] + # noise + for layer_idx in range(self.num_layers): + resolution = 2 ** ((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode="upsample", + ) + ) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + ) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + 
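+        # one 4x4 noise map for the first style conv, then two maps per
+        # resolution (one for each style conv) from 8x8 up to the output size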
noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn( + num_latent, self.num_style_feat, device=self.constant_input.weight.device + ) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward( + self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorClean. + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None diff --git a/custom_nodes/ComfyUI-ReActor/r_chainner/model_loading.py b/custom_nodes/ComfyUI-ReActor/r_chainner/model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4e75192f758c887b0e0e256a54dfa8f8acbe1e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_chainner/model_loading.py @@ -0,0 +1,28 @@ +from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean +from r_chainner.types import PyTorchModel + + +class 
UnsupportedModel(Exception): + pass + + +def load_state_dict(state_dict) -> PyTorchModel: + + state_dict_keys = list(state_dict.keys()) + + if "params_ema" in state_dict_keys: + state_dict = state_dict["params_ema"] + elif "params-ema" in state_dict_keys: + state_dict = state_dict["params-ema"] + elif "params" in state_dict_keys: + state_dict = state_dict["params"] + + state_dict_keys = list(state_dict.keys()) + + # GFPGAN + if ( + "toRGB.0.weight" in state_dict_keys + and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys + ): + model = GFPGANv1Clean(state_dict) + return model diff --git a/custom_nodes/ComfyUI-ReActor/r_chainner/types.py b/custom_nodes/ComfyUI-ReActor/r_chainner/types.py new file mode 100644 index 0000000000000000000000000000000000000000..965a978c4c7bde38351a053cdad54e5eac3c96df --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_chainner/types.py @@ -0,0 +1,18 @@ +from typing import Union + +from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean + + +PyTorchFaceModels = (GFPGANv1Clean,) +PyTorchFaceModel = Union[GFPGANv1Clean] + + +def is_pytorch_face_model(model: object): + return isinstance(model, PyTorchFaceModels) + +PyTorchModels = (*PyTorchFaceModels, ) +PyTorchModel = Union[PyTorchFaceModel] + + +def is_pytorch_model(model: object): + return isinstance(model, PyTorchModels) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47c76ab3a00100166aec9b031dd2ce771366af55 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/__init__.py @@ -0,0 +1,102 @@ +import os +import torch +from torch import nn +from copy import deepcopy +import pathlib + +from r_facelib.utils import load_file_from_url +from r_facelib.utils import download_pretrained_models +from r_facelib.detection.yolov5face.models.common import Conv + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device='cuda'): + if 'retinaface' in model_name: + model = init_retinaface_model(model_name, half, device) + elif 'YOLOv5' in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + return model + + +def init_retinaface_model(model_name, half=False, device='cuda'): + if model_name == 'retinaface_resnet50': + model = RetinaFace(network_name='resnet50', half=half) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' + elif model_name == 'retinaface_mobile0.25': + model = RetinaFace(network_name='mobile0.25', half=half) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' 
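+    # (checkpoints saved from nn.DataParallel / DistributedDataParallel prefix
+    # every key with 'module.', so strip it before loading into a bare model)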
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device='cuda'): + current_dir = str(pathlib.Path(__file__).parent.resolve()) + if model_name == 'YOLOv5l': + model = YoloDetector(config_name=current_dir+'/yolov5face/models/yolov5l.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' + elif model_name == 'YOLOv5n': + model = YoloDetector(config_name=current_dir+'/yolov5face/models/yolov5n.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/align_trans.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/align_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..07f1eb365462c2ec5bbac6d1854c786b6fd6be90 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/align_trans.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], + [33.54930115, 92.3655014], [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + + def 
__str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): + + return tmp_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 
np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: + ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException('facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/matlab_cp2tform.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/matlab_cp2tform.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a8b54a91709c71437e15c68d3be9a9b0a20a34 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/matlab_cp2tform.py @@ -0,0 +1,317 @@ +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def 
__str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=-1) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + T = inv(Tinv) + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + 
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface.py new file mode 100644 index 
0000000000000000000000000000000000000000..35829e019dc9b93f58fd0ad9fb8f0bb1b5acb5c6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface.py @@ -0,0 +1,389 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter + +from modules import shared + +from r_facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face +from r_facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head +from r_facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, + py_cpu_nms) + +#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif torch.backends.mps.is_available(): + device = torch.device('mps') +# elif hasattr(torch,'dml'): +# device = torch.device('dml') +elif hasattr(torch,'dml') or hasattr(torch,'privateuseone'): # AMD + if shared.cmd_opts is not None: # A1111 + if shared.cmd_opts.device_id is not None: + device = torch.device(f'privateuseone:{shared.cmd_opts.device_id}') + else: + device = torch.device('privateuseone:0') + else: + device = torch.device('privateuseone:0') +else: + device = torch.device('cpu') + + +def generate_config(network_name): + + cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'return_layers': { + 'stage1': 1, + 'stage2': 2, + 'stage3': 3 + }, + 'in_channel': 32, + 'out_channel': 64 + } + + cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 24, + 'ngpu': 4, + 'epoch': 100, + 'decay1': 70, + 'decay2': 90, + 'image_size': 840, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 + } + + if network_name == 'mobile0.25': + return cfg_mnet + elif network_name == 'resnet50': + return cfg_re50 + else: + raise NotImplementedError(f'network_name={network_name}') + + +class RetinaFace(nn.Module): + + def __init__(self, network_name='resnet50', half=False, phase='test'): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg['name'] + + self.model_name = f'retinaface_{network_name}' + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1., None, None + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. 
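+        # Build order: backbone (MobileNetV1 or ResNet50) -> FPN -> three SSH
+        # context modules -> per-level class / bbox / landmark heads.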
+ backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + self.to(device) + out = self.body(inputs) + + if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = (torch.cat(tmp, dim=1)) + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + 
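+        Detect faces in a single image: run the network, decode boxes and
+        landmarks against the prior anchors, drop low-confidence detections
+        and apply NMS.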
Params: + imgs: BGR image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. 
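+        Returns:
+            frames: torch.Tensor of shape [n, c, h, w] (float32, BGR).
+            resize: the scale factor that was actually applied.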
+ """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_net.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_net.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6aa82d3e9055a838f1f9076b12f05fdfc154d0 --- /dev/null +++ 
b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_net.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 
1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_utils.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c357757741c6d9bd7ce4d8ce740fefd51850fbf --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/retinaface/retinaface_utils.py @@ -0,0 +1,421 @@ +import numpy as np +import torch +import torchvision +from itertools import product as product +from math import ceil + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # 
back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + keep = torchvision.ops.nms( + boxes=torch.Tensor(dets[:, :4]), + scores=torch.Tensor(dets[:, 4]), + iou_threshold=thresh, + ) + + return list(keep) + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + ( + boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:] / 2), + 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy + boxes[:, 2:] - boxes[:, :2], + 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when matching boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ encoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ encoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence + 3)landm preds. 
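# Worked IoU example for the helpers above; jaccard (torch) and matrix_iou (numpy)
# agree on the same pair of boxes. Values are made up: the boxes intersect in a
# 1x1 square, the union area is 4 + 4 - 1 = 7, so IoU = 1/7.
import torch
from r_facelib.detection.retinaface.retinaface_utils import jaccard, matrix_iou

a = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
b = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
print(jaccard(a, b))                      # tensor([[0.1429]])
print(matrix_iou(a.numpy(), b.numpy()))   # array([[0.14285714]])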
+ """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. 
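# Round-trip sketch for the box regression targets: encode() maps a matched
# ground-truth box (point form) onto offsets relative to a prior (center-size form)
# and decode() inverts it. The numbers are made up for illustration.
import torch
from r_facelib.detection.retinaface.retinaface_utils import decode, encode

priors = torch.tensor([[0.50, 0.50, 0.20, 0.30]])   # cx, cy, w, h (normalized)
gt = torch.tensor([[0.45, 0.40, 0.65, 0.70]])       # x1, y1, x2, y2
variances = [0.1, 0.2]

targets = encode(gt, priors, variances)
assert torch.allclose(decode(targets, priors, variances), gt, atol=1e-6)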
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/face_detector.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..920e6af71728195abe06dbdb2965ce04a1931f07 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/face_detector.py @@ -0,0 +1,141 @@ +import copy +from pathlib import Path + +import cv2 +import numpy as np +import torch +from torch import torch_version + +from r_facelib.detection.yolov5face.models.common import Conv +from r_facelib.detection.yolov5face.models.yolo import Model +from r_facelib.detection.yolov5face.utils.datasets import letterbox +from r_facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +print(f"Torch version: {torch.__version__}") +IS_HIGH_VERSION = torch_version.__version__ >= "1.9.0" + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device='cuda', + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. + """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
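# Hedged usage sketch for YoloDetector. The config and image paths are placeholders,
# and loading trained weights into self.detector is assumed to happen elsewhere;
# __init__ above only builds the architecture from a .yaml config.
import cv2
from r_facelib.detection.yolov5face.face_detector import YoloDetector

detector = YoloDetector('yolov5n.yaml', min_face=10, target_size=640, device='cpu')
img = cv2.imread('face.jpg')  # BGR image; detect_faces converts it to RGB internally
dets = detector.detect_faces(img, conf_thres=0.7, iou_thres=0.5)
# dets is an (N, 15) array per the code below: x1, y1, x2, y2, a padding column,
# then five (x, y) landmarks, or None when no face passes the thresholds.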
+ """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + # Pass input images through face detector + images = imgs if isinstance(imgs, list) else [imgs] + images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + origimgs = copy.deepcopy(images) + + images = self._preprocess(images) + + if IS_HIGH_VERSION: + with torch.inference_mode(): # for pytorch>=1.9 + pred = self.detector(images)[0] + else: + with torch.no_grad(): # for pytorch<1.9 + pred = self.detector(images)[0] + + bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) + + # return bboxes, points + if not isListempty(points): + bboxes = np.array(bboxes).reshape(-1,4) + points = np.array(points).reshape(-1,10) + padding = bboxes[:,0].reshape(-1,1) + return np.concatenate((bboxes, padding, points), axis=1) + else: + return None + + def __call__(self, *args): + return self.predict(*args) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/common.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d088a2d647d52b3ea6130e6e24c5b620f0df2048 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/common.py @@ -0,0 +1,299 @@ +# This file contains modules common to various models + +import math + +import numpy as np +import torch +from torch import nn + +from r_facelib.detection.yolov5face.utils.datasets import letterbox +from r_facelib.detection.yolov5face.utils.general import ( + 
make_divisible, + non_max_suppression, + scale_coords, + xyxy2xywh, +) + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc") + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + return x.view(batchsize, -1, height, width) + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class StemBlock(nn.Module): + def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True): + super().__init__() + self.stem_1 = Conv(c1, c2, k, s, p, g, act) + self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0) + self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1) + self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) + self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0) + + def forward(self, x): + stem_1_out = self.stem_1(x) + stem_2a_out = self.stem_2a(stem_1_out) + stem_2b_out = self.stem_2b(stem_2a_out) + stem_2p_out = self.stem_2p(stem_1_out) + return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + return 
self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class ShuffleV2Block(nn.Module): + def __init__(self, inp, oup, stride): + super().__init__() + + if not 1 <= stride <= 3: + raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 2 + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + out = channel_shuffle(out, 2) + return out + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class AutoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super().__init__() + self.model = model.eval() + + def autoshape(self): + print("autoShape already enabled, skipping... 
") # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. For height=720, width=1280, RGB images example inputs are: + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1 = [], [] # image and inference shapes + for i, im in enumerate(imgs): + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = size / max(s) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, names=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 
'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/experimental.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..1e425e620e4269354da1e0bc46ad9f18d2917b2b --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/experimental.py @@ -0,0 +1,45 @@ +# # This file contains experimental modules + +import numpy as np +import torch +from torch import nn + +from r_facelib.detection.yolov5face.models.common import Conv + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super().__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolo.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolo.py new file mode 100644 index 0000000000000000000000000000000000000000..409c3ab97d34ce0ef190138234fe12acfe38a514 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolo.py @@ -0,0 +1,235 @@ +import math +from copy import deepcopy +from pathlib import Path + +import torch +import yaml # for torch hub +from torch import nn + +from r_facelib.detection.yolov5face.models.common import ( + C3, + NMS, + SPP, + AutoShape, + Bottleneck, + BottleneckCSP, + Concat, + Conv, + DWConv, + Focus, + ShuffleV2Block, + StemBlock, +) +from r_facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from r_facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from r_facelib.detection.yolov5face.utils.general import make_divisible +from r_facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn + + +class Detect(nn.Module): + stride = None # strides computed during build + export = False # onnx export + + def __init__(self, nc=80, anchors=(), ch=()): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 + 10 # number of outputs per 
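# Worked example of the per-anchor output layout used by Detect for this face model:
nc = 1                  # a single "face" class
no = nc + 5 + 10        # 4 box coords + 1 objectness + 10 landmark coords + nc scores = 16
na = 3                  # anchors per detection layer
print(no * na)          # 48 channels out of each 1x1 conv, later reshaped to (bs, na, ny, nx, no)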
anchor + + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + a = torch.tensor(anchors).float().view(self.nl, -1, 2) + self.register_buffer("anchors", a) # shape(nl,na,2) + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + + def forward(self, x): + z = [] # inference output + if self.export: + for i in range(self.nl): + x[i] = self.m[i](x[i]) + return x + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = torch.full_like(x[i], 0) + y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() + y[..., 5:15] = x[i][..., 5:15] + + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + + y[..., 5:7] = ( + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x1 y1 + y[..., 7:9] = ( + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x2 y2 + y[..., 9:11] = ( + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x3 y3 + y[..., 11:13] = ( + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x4 y4 + y[..., 13:15] = ( + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x5 y5 + + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + @staticmethod + def _make_grid(nx=20, ny=20): + # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10 + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + + +class Model(nn.Module): + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes + super().__init__() + self.yaml_file = Path(cfg).name + with Path(cfg).open(encoding="utf8") as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + self.yaml["nc"] = nc # override yaml value + + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 128 # 2x min stride + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + def forward(self, x): + return self.forward_once(x) # single-scale inference, train + + def forward_once(self, x): + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + + x = m(x) # run + y.append(x if 
m.i in self.save else None) # save output + + return x + + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # https://arxiv.org/abs/1708.02002 section 3.3 + m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print("Fusing layers... ") + for m in self.model.modules(): + if isinstance(m, Conv) and hasattr(m, "bn"): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.fuseforward # update forward + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + return self + + def nms(self, mode=True): # add or remove NMS module + present = isinstance(self.model[-1], NMS) # last layer is NMS + if mode and not present: + print("Adding NMS... ") + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name=str(m.i), module=m) # add + self.eval() + elif not mode and present: + print("Removing NMS... ") + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print("Adding autoShape... ") + m = AutoShape(self) # wrap model + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes + return m + + +def parse_model(d, ch): # model_dict, input_channels(3) + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except: + pass + + n = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [ + Conv, + Bottleneck, + SPP, + DWConv, + MixConv2d, + Focus, + CrossConv, + BottleneckCSP, + C3, + ShuffleV2Block, + StemBlock, + ]: + c1, c2 = ch[f], args[0] + + c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3]: + args.insert(2, n) + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) + elif m is Detect: + args.append([ch[x + 1] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + 
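# Sketch of the channel scaling applied above: parse_model multiplies each requested
# channel count by width_multiple (gw) and rounds up to a multiple of 8. The gw value
# below is purely illustrative; the yolov5l/yolov5n configs in this diff both use 1.0.
import math

def make_divisible(x, divisor):
    return math.ceil(x / divisor) * divisor

gw = 0.75                            # hypothetical width_multiple
print(make_divisible(256 * gw, 8))   # 192 -> the channel count actually built for a "256" layer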
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + ch.append(c2) + return nn.Sequential(*layers), sorted(save) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5l.yaml b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0532b0e22fa7f59349b178146ffddcfdb368aba6 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5l.yaml @@ -0,0 +1,47 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 + [-1, 1, SPP, [1024, [3,5,7]]], + [-1, 3, C3, [1024, False]], # 8 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 5], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 12 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 3], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 16 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 13], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 19 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 22 (P5/32-large) + + [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5n.yaml b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5n.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caba6bed674aa2213b110f19e04eb352ffbeaf1e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/models/yolov5n.yaml @@ -0,0 +1,45 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 + [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 + [-1, 3, ShuffleV2Block, [128, 1]], # 2 + [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 + [-1, 7, ShuffleV2Block, [256, 1]], # 4 + [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 + [-1, 3, ShuffleV2Block, [512, 1]], # 6 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P4 + [-1, 1, C3, [128, False]], # 10 + + [-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], # cat backbone P3 + [-1, 1, C3, [128, False]], # 14 (P3/8-small) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 11], 1, Concat, [1]], # cat head P4 + [-1, 1, C3, [128, False]], # 17 (P4/16-medium) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 7], 1, Concat, [1]], # cat head P5 + [-1, 1, C3, [128, False]], # 20 (P5/32-large) + + [[14, 
17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/autoanchor.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/autoanchor.py new file mode 100644 index 0000000000000000000000000000000000000000..a4eba3e94888709be7d2a7c7499fbcc1808b4a88 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/autoanchor.py @@ -0,0 +1,12 @@ +# Auto-anchor utils + + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + print("Reversing anchor order") + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/datasets.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e672b136f56fd6b05038e24377908361a54fe519 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/datasets.py @@ -0,0 +1,35 @@ +import cv2 +import numpy as np + + +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): + # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 + shape = img.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/extract_ckpt.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..ad427c9592365329d8451e859e3dc5fa90735050 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,5 @@ +import torch +import sys +sys.path.insert(0,'./facelib/detection/yolov5face') +model = 
torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] +torch.save(model.state_dict(),'../../models/facedetection') \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/general.py b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8e14f56a107ec3a4269c382cfc5168ad780ffc --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/general.py @@ -0,0 +1,271 @@ +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
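# Quick round-trip check for the two converters above (values are arbitrary):
import numpy as np
from r_facelib.detection.yolov5face.utils.general import xywh2xyxy, xyxy2xywh

boxes_xyxy = np.array([[10.0, 20.0, 50.0, 80.0]])
boxes_xywh = xyxy2xywh(boxes_xyxy)   # [[30., 50., 40., 60.]] -> cx, cy, w, h
assert np.allclose(xywh2xyxy(boxes_xywh), boxes_xyxy)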
+ Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) 
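# The "offset boxes by class" trick used in both NMS helpers in this file: adding
# class_id * max_wh to every coordinate pushes boxes of different classes far apart,
# so a single torchvision.ops.nms call never suppresses across classes.
import torch
import torchvision

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])  # identical boxes
scores = torch.tensor([0.9, 0.8])
cls = torch.tensor([[0.0], [1.0]])
print(torchvision.ops.nms(boxes, scores, 0.5))               # tensor([0]): one box survives
print(torchvision.ops.nms(boxes + cls * 4096, scores, 0.5))  # tensor([0, 1]): both kept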
maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/torch_utils.py 
b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af2d06587b2d07b2eab199a8484380fde1de5c3c --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/detection/yolov5face/utils/torch_utils.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] + for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db827ff60ab8d7dd108a2c1e317cefe40e3fab85 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from r_facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/bisenet.py b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/bisenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3898cab76ae5876459cd4899c54cafa14234971d --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + 
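The `fuse_conv_and_bn` helper added in `torch_utils.py` above folds a `BatchNorm2d` into the preceding `Conv2d`, so that at inference time a single convolution reproduces conv followed by BN. A minimal standalone sketch of the same folding with a numerical check; this is illustrative only and not part of the patch:

```python
import torch
from torch import nn

# Fold BatchNorm2d into the preceding Conv2d, as fuse_conv_and_bn does:
#   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
#     = conv'(x) + b',  with  W' = W * gamma / sqrt(var + eps),  b' = beta - gamma * mean / sqrt(var + eps)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()          # fusion assumes inference-time (running) statistics

# give the BN non-trivial statistics so the check below is meaningful
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.1, 0.1)
bn.running_mean.uniform_(-0.2, 0.2)
bn.running_var.uniform_(0.5, 1.5)

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)       # per-output-channel factor
    fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_(bn.bias - scale * bn.running_mean)           # conv here has no bias of its own

    x = torch.randn(2, 3, 32, 32)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)       # fused conv matches conv -> bn
```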
def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + 
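`AttentionRefinementModule` above is a squeeze-and-excite style channel gate: global average pooling summarises each channel, a 1x1 conv plus BN and a sigmoid turn that summary into a per-channel weight in (0, 1), and the feature map is rescaled channel-wise (`FeatureFusionModule` follows the same pattern with a conv-ReLU-conv bottleneck and adds the gated features back to the input). A minimal functional sketch of the gate, with an assumed channel count of 256:

```python
import torch
from torch import nn
import torch.nn.functional as F

channels = 256                                  # assumed for illustration
conv_atten = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
bn_atten = nn.BatchNorm2d(channels).eval()      # eval: use running stats on the pooled 1x1 map

feat = torch.randn(1, channels, 32, 32)
atten = F.avg_pool2d(feat, feat.size()[2:])     # global average pool -> (1, C, 1, 1)
atten = torch.sigmoid(bn_atten(conv_atten(atten)))
out = feat * atten                              # each channel rescaled by its gate value
assert out.shape == feat.shape
```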
feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) + out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/parsenet.py b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/parsenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e178ebe43a1ef666aaea0bc0faf629485c22a24f --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/parsenet.py @@ -0,0 +1,194 @@ +"""Modified from https://github.com/chaofengc/PSFRGAN +""" +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type='bn'): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == 'bn': + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == 'in': + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == 'gn': + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == 'pixel': + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == 'layer': + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == 'none': + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f'Norm type {norm_type} not support.' + + def forward(self, x, ref=None): + if self.norm_type == 'spade': + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type='relu'): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == 'relu': + self.func = nn.ReLU(True) + elif relu_type == 'leakyrelu': + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == 'prelu': + self.func = nn.PReLU(channels) + elif relu_type == 'selu': + self.func = nn.SELU(True) + elif relu_type == 'none': + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f'Relu type {relu_type} not support.' + + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + scale='none', + norm_type='none', + relu_type='none', + use_pad=True, + bias=True): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ['bn']: + bias = False + + stride = 2 if scale == 'down' else 1 + + self.scale_func = lambda x: x + if scale == 'up': + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): + super(ResidualBlock, self).__init__() + + if scale == 'none' and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='LeakyReLU', + norm_type='bn', + ch_range=[32, 256]): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/resnet.py b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fec8e82cf64469fb51be21ad5130217052addbda --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class 
BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/utils/__init__.py b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f03b1c2bafcd7759cb7e8722a0c6715f201a46dc --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir + +__all__ = [ + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', + 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' +] diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_restoration_helper.py b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_restoration_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..5935125f42c0a7561cb1713b24235b2c25b332a0 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_restoration_helper.py @@ -0,0 +1,455 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from r_facelib.detection import init_detection_model +from r_facelib.parsing import init_parsing_model +from r_facelib.utils.misc import img2tensor, imwrite + + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = 
get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = upscale_factor + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + + if self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + # init face detection model + self.face_det = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + + if min(self.input_img.shape[:2])<512: + f = 
512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_det.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + 
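`get_largest_face` and `get_center_face` defined earlier in this file pick one detection among several, either by clipped box area or by distance from the image centre. A small sketch with made-up boxes; the import path is an assumption and depends on how the custom node lands on `sys.path`:

```python
import numpy as np
# Assumed import path for the helpers shown above.
from r_facelib.utils.face_restoration_helper import get_largest_face, get_center_face

h, w = 100, 100
det_faces = [
    np.array([10, 10, 30, 30, 0.9]),   # 20x20 box near the top-left corner
    np.array([40, 40, 90, 90, 0.8]),   # 50x50 box straddling the centre
]

face, idx = get_largest_face(det_faces, h, w)
print(idx)   # 1 -> clipped areas are 400 vs 2500

face, idx = get_center_face(det_faces, h, w)
print(idx)   # 1 -> box centres are (20, 20) and (65, 65); (65, 65) is closer to (50, 50)
```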
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, face): + self.restored_faces.append(face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + if 
face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with 
torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255. + + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int') + inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask) + + if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel + alpha = upsample_img[:, :, 3:] + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3] + upsample_img = np.concatenate((upsample_img, alpha), axis=2) + else: + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img + + if np.max(upsample_img) > 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:,:,0] = 0 + img_color[:,:,1] = 255 + img_color[:,:,2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_utils.py b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e95207621849324ed75dfcb4c474860e9e603d3 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/face_utils.py @@ -0,0 +1,248 @@ +import cv2 +import numpy as np +import torch + + +def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): + left, top, right, bot = bbox + width = right - left + height = bot - top + + if preserve_aspect: + width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) + height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) + else: + width_increase = height_increase = increase_area + left = int(left - width_increase * width) + top = int(top - height_increase * height) + right = int(right + width_increase * width) + bot = int(bot + height_increase * height) + return (left, top, right, bot) + + +def get_valid_bboxes(bboxes, h, w): + left = max(bboxes[0], 0) + top = max(bboxes[1], 0) + right = min(bboxes[2], w) + bottom = min(bboxes[3], h) + return (left, top, right, bottom) + + +def align_crop_face_landmarks(img, + landmarks, + output_size, + transform_size=None, + enable_padding=True, + return_inverse_affine=False, + shrink_ratio=(1, 1)): + """Align and crop face with landmarks. + + The output_size and transform_size are based on width. The height is + adjusted based on shrink_ratio_h/shring_ration_w. + + Modified from: + https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + Args: + img (Numpy array): Input image. 
+ landmarks (Numpy array): 5 or 68 or 98 landmarks. + output_size (int): Output face size. + transform_size (ing): Transform size. Usually the four time of + output_size. + enable_padding (float): Default: True. + shrink_ratio (float | tuple[float] | list[float]): Shring the whole + face for height and width (crop larger area). Default: (1, 1). + + Returns: + (Numpy array): Cropped face. + """ + lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5 + + if isinstance(shrink_ratio, (float, int)): + shrink_ratio = (shrink_ratio, shrink_ratio) + if transform_size is None: + transform_size = output_size * 4 + + # Parse landmarks + lm = np.array(landmarks) + if lm.shape[0] == 5 and lm_type == 'retinaface_5': + eye_left = lm[0] + eye_right = lm[1] + mouth_avg = (lm[3] + lm[4]) * 0.5 + elif lm.shape[0] == 5 and lm_type == 'dlib_5': + lm_eye_left = lm[2:4] + lm_eye_right = lm[0:2] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = lm[4] + elif lm.shape[0] == 68: + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[48] + lm[54]) * 0.5 + elif lm.shape[0] == 98: + lm_eye_left = lm[60:68] + lm_eye_right = lm[68:76] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[76] + lm[82]) * 0.5 + + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1 # TODO: you can edit it to get larger rect + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + x *= shrink_ratio[1] # width + y *= shrink_ratio[0] # height + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + + quad_ori = np.copy(quad) + # Shrink, for large face + # TODO: do we really need shrink + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + h, w = img.shape[0:2] + rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) + img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1]:crop[3], crop[0]:crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) + if enable_padding and 
max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype('float32') + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img + # float32, [0, 255] + return img + + +if __name__ == '__main__': + import os + + from custom_nodes.facerestore.facelib.detection import init_detection_model + from custom_nodes.facerestore.facelib.utils.face_restoration_helper import get_largest_face + + img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png' + img_name = os.splitext(os.path.basename(img_path))[0] + + # initialize model + det_net = init_detection_model('retinaface_resnet50', half=False) + img_ori = cv2.imread(img_path) + h, w = img_ori.shape[0:2] + # if larger than 800, scale it + scale = max(h / 800, w / 800) + if scale > 1: + img = 
cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) + + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + if scale > 1: + bboxes *= scale # the score is incorrect + bboxes = get_largest_face(bboxes, h, w)[0] + + landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) + + cropped_face, inverse_affine = align_crop_face_landmarks( + img_ori, + landmarks, + output_size=512, + transform_size=None, + enable_padding=True, + return_inverse_affine=True, + shrink_ratio=(1, 1)) + + cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face) + img = paste_face_back(img_ori, cropped_face, inverse_affine) + cv2.imwrite(f'tmp/{img_name}_back.png', img) diff --git a/custom_nodes/ComfyUI-ReActor/r_facelib/utils/misc.py b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..68b7fe9dbf19f2802d3789a2c740f0f03de44d61 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/r_facelib/utils/misc.py @@ -0,0 +1,143 @@ +import cv2 +import os +import os.path as osp +import torch +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse +# from basicsr.utils.download_util import download_file_from_google_drive +#import gdown + + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def download_pretrained_models(file_ids, save_path_root): + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = 'https://drive.google.com/uc?id='+file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + print("skipping gdown in facelib/utils/misc.py "+file_url) + #gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accepts Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + print("skipping gdown in facelib/utils/misc.py "+file_url) + #gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. 
+ """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) diff --git a/custom_nodes/ComfyUI-ReActor/reactor_patcher.py b/custom_nodes/ComfyUI-ReActor/reactor_patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..efa3c9eb6ffb5d153d0f1df7d61dbdac63dc0f76 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/reactor_patcher.py @@ -0,0 +1,161 @@ +import os.path as osp +import glob +import logging +import insightface +from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession +from insightface.model_zoo.retinaface import RetinaFace +from insightface.model_zoo.landmark import Landmark +from insightface.model_zoo.attribute import Attribute +from insightface.model_zoo.inswapper import INSwapper +from insightface.model_zoo.arcface_onnx import ArcFaceONNX +from insightface.app import FaceAnalysis +from insightface.utils import DEFAULT_MP_NAME, ensure_available +from insightface.model_zoo import model_zoo +import onnxruntime +import onnx +from onnx import numpy_helper +from scripts.reactor_logger import logger + + +def patched_get_model_log(self, **kwargs): + session = PickableInferenceSession(self.onnx_file, **kwargs) + print(f'Applied providers: {session._providers}, with options: 
{session._provider_options}') + inputs = session.get_inputs() + input_cfg = inputs[0] + input_shape = input_cfg.shape + outputs = session.get_outputs() + + if len(outputs) >= 5: + return RetinaFace(model_file=self.onnx_file, session=session) + elif input_shape[2] == 192 and input_shape[3] == 192: + return Landmark(model_file=self.onnx_file, session=session) + elif input_shape[2] == 96 and input_shape[3] == 96: + return Attribute(model_file=self.onnx_file, session=session) + elif len(inputs) == 2 and input_shape[2] == 128 and input_shape[3] == 128: + return INSwapper(model_file=self.onnx_file, session=session) + elif len(inputs) == 2 and input_shape[2] == 256 and input_shape[3] == 256: + return INSwapper(model_file=self.onnx_file, session=session) + elif input_shape[2] == input_shape[3] and input_shape[2] >= 112 and input_shape[2] % 16 == 0: + return ArcFaceONNX(model_file=self.onnx_file, session=session) + else: + return None + +def patched_get_model(self, **kwargs): + session = PickableInferenceSession(self.onnx_file, **kwargs) + inputs = session.get_inputs() + input_cfg = inputs[0] + input_shape = input_cfg.shape + outputs = session.get_outputs() + + if len(outputs) >= 5: + return RetinaFace(model_file=self.onnx_file, session=session) + elif input_shape[2] == 192 and input_shape[3] == 192: + return Landmark(model_file=self.onnx_file, session=session) + elif input_shape[2] == 96 and input_shape[3] == 96: + return Attribute(model_file=self.onnx_file, session=session) + elif len(inputs) == 2 and input_shape[2] == 128 and input_shape[3] == 128: + return INSwapper(model_file=self.onnx_file, session=session) + elif len(inputs) == 2 and input_shape[2] == 256 and input_shape[3] == 256: + return INSwapper(model_file=self.onnx_file, session=session) + elif input_shape[2] == input_shape[3] and input_shape[2] >= 112 and input_shape[2] % 16 == 0: + return ArcFaceONNX(model_file=self.onnx_file, session=session) + else: + return None + + +def patched_faceanalysis_init(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): + onnxruntime.set_default_logger_severity(3) + self.models = {} + self.model_dir = ensure_available('models', name, root=root) + onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) + onnx_files = sorted(onnx_files) + for onnx_file in onnx_files: + model = model_zoo.get_model(onnx_file, **kwargs) + if model is None: + print('model not recognized:', onnx_file) + elif allowed_modules is not None and model.taskname not in allowed_modules: + print('model ignore:', onnx_file, model.taskname) + del model + elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): + self.models[model.taskname] = model + else: + print('duplicated model task type, ignore:', onnx_file, model.taskname) + del model + assert 'detection' in self.models + self.det_model = self.models['detection'] + + +def patched_faceanalysis_prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): + self.det_thresh = det_thresh + assert det_size is not None + self.det_size = det_size + for taskname, model in self.models.items(): + if taskname == 'detection': + model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) + else: + model.prepare(ctx_id) + + +def patched_inswapper_init(self, model_file=None, session=None): + self.model_file = model_file + self.session = session + model = onnx.load(self.model_file) + graph = model.graph + self.emap = numpy_helper.to_array(graph.initializer[-1]) + self.input_mean = 0.0 + self.input_std = 255.0 + 
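The patched `ModelRouter.get_model` above (and its logging variant) routes an `.onnx` file to the right insightface wrapper purely from the session's input/output shapes. A hedged usage sketch; the model path is a placeholder and the patched router is assumed to have been installed via `apply_patch` further below:

```python
from insightface.model_zoo.model_zoo import ModelRouter

router = ModelRouter("models/insightface/inswapper_128.onnx")   # placeholder path, adjust to your setup
model = router.get_model(providers=["CPUExecutionProvider"])

# Routing rules mirrored from patched_get_model above:
#   >= 5 outputs                         -> RetinaFace   (face detection)
#   192x192 input                        -> Landmark     (facial landmarks)
#   96x96 input                          -> Attribute    (gender / age)
#   two inputs, 128x128 or 256x256       -> INSwapper    (face swapping)
#   square input >= 112, multiple of 16  -> ArcFaceONNX  (identity embedding)
#   anything else                        -> None
print(type(model).__name__)   # expected: "INSwapper" for an inswapper model
```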
if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + inputs = self.session.get_inputs() + self.input_names = [] + for inp in inputs: + self.input_names.append(inp.name) + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.output_names = output_names + assert len(self.output_names) == 1 + input_cfg = inputs[0] + input_shape = input_cfg.shape + self.input_shape = input_shape + self.input_size = tuple(input_shape[2:4][::-1]) + + +def pathced_retinaface_prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None and self.input_size is None: + self.input_size = input_size + + +def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init, retinaface_prepare): + insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model + insightface.app.FaceAnalysis.__init__ = faceanalysis_init + insightface.app.FaceAnalysis.prepare = faceanalysis_prepare + insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init + insightface.model_zoo.retinaface.RetinaFace.prepare = retinaface_prepare + + +# original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__, RetinaFace.prepare] +original_functions = [patched_get_model_log, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__, RetinaFace.prepare] +patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init, pathced_retinaface_prepare] + + +def apply_patch(console_log_level): + if console_log_level == 0: + patch_insightface(*patched_functions) + logger.setLevel(logging.WARNING) + elif console_log_level == 1: + patch_insightface(*patched_functions) + logger.setLevel(logging.STATUS) + elif console_log_level == 2: + patch_insightface(*original_functions) + logger.setLevel(logging.INFO) diff --git a/custom_nodes/ComfyUI-ReActor/reactor_utils.py b/custom_nodes/ComfyUI-ReActor/reactor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af79497eca14ac8ac576651edb13d896c6b5f494 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/reactor_utils.py @@ -0,0 +1,231 @@ +import os +from PIL import Image +import numpy as np +import torch +from torchvision.utils import make_grid +import cv2 +import math +import logging +import hashlib +from insightface.app.common import Face +from safetensors.torch import save_file, safe_open +from tqdm import tqdm +import urllib.request +import onnxruntime +from typing import Any +import folder_paths + +ORT_SESSION = None + +def tensor_to_pil(img_tensor, batch_index=0): + # Convert tensor of shape [batch_size, channels, height, width] at the batch_index to PIL Image + img_tensor = img_tensor[batch_index].unsqueeze(0) + i = 255. 
* img_tensor.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze()) + return img + + +def batch_tensor_to_pil(img_tensor): + # Convert tensor of shape [batch_size, channels, height, width] to a list of PIL Images + return [tensor_to_pil(img_tensor, i) for i in range(img_tensor.shape[0])] + + +def pil_to_tensor(image): + # Takes a PIL image and returns a tensor of shape [1, height, width, channels] + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image).unsqueeze(0) + if len(image.shape) == 3: # If the image is grayscale, add a channel dimension + image = image.unsqueeze(-1) + return image + + +def batched_pil_to_tensor(images): + # Takes a list of PIL images and returns a tensor of shape [batch_size, height, width, channels] + return torch.cat([pil_to_tensor(image) for image in images], dim=0) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
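`pil_to_tensor` / `tensor_to_pil` above convert between PIL images and the IMAGE tensors ComfyUI passes between nodes, laid out `[batch, height, width, channels]` with float32 values in [0, 1] (that is the layout `pil_to_tensor` actually produces). A quick round-trip sketch, assuming `reactor_utils` is importable inside a ComfyUI environment, since it imports `folder_paths` at module level:

```python
import numpy as np
import torch
from PIL import Image

# Assumed import path; requires a ComfyUI environment on sys.path.
from reactor_utils import pil_to_tensor, tensor_to_pil

pil_in = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))

t = pil_to_tensor(pil_in)        # torch.Size([1, 64, 64, 3]), float32 in [0, 1]
pil_out = tensor_to_pil(t)       # back to an 8-bit RGB PIL image

assert t.dtype == torch.float32 and t.shape == (1, 64, 64, 3)
# float32 divide/multiply by 255 plus uint8 truncation can be off by at most one level
diff = np.abs(np.asarray(pil_in).astype(int) - np.asarray(pil_out).astype(int))
assert diff.max() <= 1
```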
+ img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def rgba2rgb_tensor(rgba): + r = rgba[...,0] + g = rgba[...,1] + b = rgba[...,2] + return torch.stack([r, g, b], dim=3) + + +def download(url, path, name): + request = urllib.request.urlopen(url) + total = int(request.headers.get('Content-Length', 0)) + with tqdm(total=total, desc=f'[ReActor] Downloading {name} to {path}', unit='B', unit_scale=True, unit_divisor=1024) as progress: + urllib.request.urlretrieve(url, path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) + + +def move_path(old_path, new_path): + if os.path.exists(old_path): + try: + models = os.listdir(old_path) + for model in models: + move_old_path = os.path.join(old_path, model) + move_new_path = os.path.join(new_path, model) + os.rename(move_old_path, move_new_path) + os.rmdir(old_path) + except Exception as e: + print(f"Error: {e}") + new_path = old_path + + +def addLoggingLevel(levelName, levelNum, methodName=None): + if not methodName: + methodName = levelName.lower() + + def logForLevel(self, message, *args, **kwargs): + if self.isEnabledFor(levelNum): + self._log(levelNum, message, args, **kwargs) + + def logToRoot(message, *args, **kwargs): + logging.log(levelNum, message, *args, **kwargs) + + logging.addLevelName(levelNum, levelName) + setattr(logging, levelName, levelNum) + setattr(logging.getLoggerClass(), methodName, logForLevel) + setattr(logging, methodName, logToRoot) + + +def get_image_md5hash(image: Image.Image): + md5hash = hashlib.md5(image.tobytes()) + return md5hash.hexdigest() + + +def save_face_model(face: Face, filename: str) -> None: + try: + tensors = { + "bbox": torch.tensor(face["bbox"]), + "kps": torch.tensor(face["kps"]), + "det_score": torch.tensor(face["det_score"]), + "landmark_3d_68": torch.tensor(face["landmark_3d_68"]), + "pose": torch.tensor(face["pose"]), + "landmark_2d_106": torch.tensor(face["landmark_2d_106"]), + "embedding": torch.tensor(face["embedding"]), + "gender": torch.tensor(face["gender"]), + "age": torch.tensor(face["age"]), + } + save_file(tensors, filename) + print(f"Face model has been saved to '{filename}'") + except Exception as e: + print(f"Error: {e}") + + +def load_face_model(filename: str): + face = {} + with safe_open(filename, framework="pt") as f: + for k in f.keys(): + face[k] = f.get_tensor(k).numpy() + return Face(face) + + +def get_ort_session(): + global ORT_SESSION + return ORT_SESSION + +def set_ort_session(model_path, providers) -> Any: + global ORT_SESSION + onnxruntime.set_default_logger_severity(3) + ORT_SESSION = onnxruntime.InferenceSession(model_path, providers=providers) + return ORT_SESSION + +def clear_ort_session() -> None: + global ORT_SESSION + ORT_SESSION = None + +def prepare_cropped_face(cropped_face): + cropped_face = cropped_face[:, :, ::-1] / 255.0 + cropped_face = (cropped_face - 0.5) / 0.5 + cropped_face = np.expand_dims(cropped_face.transpose(2, 0, 1), axis = 0).astype(np.float32) + return cropped_face + +def normalize_cropped_face(cropped_face): + cropped_face = np.clip(cropped_face, -1, 1) + cropped_face = (cropped_face + 1) / 2 + cropped_face = cropped_face.transpose(1, 2, 0) + cropped_face = (cropped_face * 255.0).round() + cropped_face = cropped_face.astype(np.uint8)[:, :, ::-1] + return cropped_face + + +# author: Trung0246 ---> +def add_folder_path_and_extensions(folder_name, full_folder_paths, extensions): + # Iterate over the list of full 
folder paths + for full_folder_path in full_folder_paths: + # Use the provided function to add each model folder path + folder_paths.add_model_folder_path(folder_name, full_folder_path) + + # Now handle the extensions. If the folder name already exists, update the extensions + if folder_name in folder_paths.folder_names_and_paths: + # Unpack the current paths and extensions + current_paths, current_extensions = folder_paths.folder_names_and_paths[folder_name] + # Update the extensions set with the new extensions + updated_extensions = current_extensions | extensions + # Reassign the updated tuple back to the dictionary + folder_paths.folder_names_and_paths[folder_name] = (current_paths, updated_extensions) + else: + # If the folder name was not present, add_model_folder_path would have added it with the last path + # Now we just need to update the set of extensions as it would be an empty set + # Also ensure that all paths are included (since add_model_folder_path adds only one path at a time) + folder_paths.folder_names_and_paths[folder_name] = (full_folder_paths, extensions) +# <--- diff --git a/custom_nodes/ComfyUI-ReActor/requirements.txt b/custom_nodes/ComfyUI-ReActor/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..dedcf30a44df22e39a86df139d3ea2331669c8ac --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/requirements.txt @@ -0,0 +1,7 @@ +albumentations>=1.4.16 +insightface==0.7.3 +onnx>=1.14.0 +opencv-python>=4.7.0.72 +numpy==1.26.3 +segment_anything +ultralytics diff --git a/custom_nodes/ComfyUI-ReActor/scripts/__init__.py b/custom_nodes/ComfyUI-ReActor/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_archs/__init__.py b/custom_nodes/ComfyUI-ReActor/scripts/r_archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_archs/codeformer_arch.py b/custom_nodes/ComfyUI-ReActor/scripts/r_archs/codeformer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3b09ebaacc0f2d4a92126d40c2ac7650151ff602 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_archs/codeformer_arch.py @@ -0,0 +1,278 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from scripts.r_archs.vqgan_arch import * +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. 
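+
+    Returns:
+        Tensor: style_std * (content_feat - content_mean) / content_std + style_mean,
+            where the mean and std are computed per sample and per channel over
+            the spatial dimensions (see calc_mean_std above).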
+ """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, 
out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator']): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits 
doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat + \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_archs/vqgan_arch.py b/custom_nodes/ComfyUI-ReActor/scripts/r_archs/vqgan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3925a27df4816144e9d29deae50164b18b1cb3 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_archs/vqgan_arch.py @@ -0,0 +1,437 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get 
quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "min_encoding_scores": min_encoding_scores, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = 
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + 
blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + 
nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) + \ No newline at end of file diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/__init__.py b/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/restorer.py b/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/restorer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62977cd0a97f4e199644f779b2a8c3047ca92e --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/restorer.py @@ -0,0 +1,130 @@ +import sys +import cv2 +import numpy as np +import torch +from torchvision.transforms.functional import normalize + +try: + import torch.cuda as cuda +except: + cuda = None + +import comfy.utils +import folder_paths +import comfy.model_management as model_management + +from scripts.reactor_logger import logger +from r_basicsr.utils.registry import ARCH_REGISTRY +from r_chainner import model_loading +from reactor_utils import ( + tensor2img, + img2tensor, + set_ort_session, + prepare_cropped_face, + normalize_cropped_face +) + + +if cuda is not None: + if cuda.is_available(): + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] +else: + providers = ["CPUExecutionProvider"] + + +def get_restored_face(cropped_face, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation: str = "Bicubic"): + + if interpolation == "Bicubic": + interpolate = cv2.INTER_CUBIC + elif interpolation == "Bilinear": + interpolate = cv2.INTER_LINEAR + elif interpolation == "Nearest": + interpolate = cv2.INTER_NEAREST + elif interpolation == "Lanczos": + interpolate = cv2.INTER_LANCZOS4 + + face_size = 512 + if "1024" in face_restore_model.lower(): + face_size = 1024 + elif "2048" in face_restore_model.lower(): + face_size = 2048 + + scale = face_size / cropped_face.shape[0] + + logger.status(f"Boosting the Face with {face_restore_model} | Face Size is set to {face_size} with Scale Factor = {scale} and '{interpolation}' interpolation") + + cropped_face = cv2.resize(cropped_face, (face_size, face_size), interpolation=interpolate) + + # For upscaling the base 128px face, I found bicubic interpolation to be the best compromise targeting antialiasing + # and detail preservation. Nearest is predictably unusable, Linear produces too much aliasing, and Lanczos produces + # too many hallucinations and artifacts/fringing. 
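+    # The if/elif chain above covers the four interpolation names exposed by the node.
+    # If other names ever need to be accepted, one possible sketch (an assumption,
+    # not part of the original node) is a lookup with a bicubic fallback:
+    #   interpolate = {
+    #       "Bicubic": cv2.INTER_CUBIC, "Bilinear": cv2.INTER_LINEAR,
+    #       "Nearest": cv2.INTER_NEAREST, "Lanczos": cv2.INTER_LANCZOS4,
+    #   }.get(interpolation, cv2.INTER_CUBIC)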
+ + model_path = folder_paths.get_full_path("facerestore_models", face_restore_model) + device = model_management.get_torch_device() + + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + + with torch.no_grad(): + + if ".onnx" in face_restore_model: # ONNX models + + ort_session = set_ort_session(model_path, providers=providers) + ort_session_inputs = {} + facerestore_model = ort_session + + for ort_session_input in ort_session.get_inputs(): + if ort_session_input.name == "input": + cropped_face_prep = prepare_cropped_face(cropped_face) + ort_session_inputs[ort_session_input.name] = cropped_face_prep + if ort_session_input.name == "weight": + weight = np.array([1], dtype=np.double) + ort_session_inputs[ort_session_input.name] = weight + + output = ort_session.run(None, ort_session_inputs)[0][0] + restored_face = normalize_cropped_face(output) + + else: # PTH models + + if "codeformer" in face_restore_model.lower(): + codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device) + checkpoint = torch.load(model_path)["params_ema"] + codeformer_net.load_state_dict(checkpoint) + facerestore_model = codeformer_net.eval() + else: + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + facerestore_model = model_loading.load_state_dict(sd).eval() + facerestore_model.to(device) + + output = facerestore_model(cropped_face_t, w=codeformer_weight)[ + 0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + + del output + torch.cuda.empty_cache() + + except Exception as error: + + print(f"\tFailed inference: {error}", file=sys.stderr) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + if face_restore_visibility < 1: + restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility + + restored_face = restored_face.astype("uint8") + return restored_face, scale diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/swapper.py b/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cfb9652fca16c7eed8d1a8308bffefc6b8bc89 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_faceboost/swapper.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np + +# The following code is almost entirely copied from INSwapper; the only change here is that we want to use Lanczos +# interpolation for the warpAffine call. Now that the face has been restored, Lanczos represents a good compromise +# whether the restored face needs to be upscaled or downscaled. 
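+# A brief note on in_swap's arguments, as far as can be read from the code below:
+#   img      - the full target frame the face is pasted back into
+#   bgr_fake - the swapped (and restored) face patch in its aligned crop space (BGR)
+#   M        - the 2x3 affine matrix from face alignment; its inverse (IM) warps the
+#              patch and its blending mask back into frame coordinates.
+# As written, the warpAffine call below is passed flags=cv2.INTER_CUBIC (bicubic).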
+def in_swap(img, bgr_fake, M): + target_img = img + IM = cv2.invertAffineTransform(M) + img_white = np.full((bgr_fake.shape[0], bgr_fake.shape[1]), 255, dtype=np.float32) + + # Note the use of bicubic here; this is functionally the only change from the source code + bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0, flags=cv2.INTER_CUBIC) + + img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + img_white[img_white > 20] = 255 + img_mask = img_white + mask_h_inds, mask_w_inds = np.where(img_mask == 255) + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h * mask_w)) + k = max(mask_size // 10, 10) + # k = max(mask_size//20, 6) + # k = 6 + kernel = np.ones((k, k), np.uint8) + img_mask = cv2.erode(img_mask, kernel, iterations=1) + kernel = np.ones((2, 2), np.uint8) + k = max(mask_size // 20, 5) + # k = 3 + # k = 3 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + k = 5 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask /= 255 + # img_mask = fake_diff + img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1]) + fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32) + fake_merged = fake_merged.astype(np.uint8) + return fake_merged diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_masking/__init__.py b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_masking/core.py b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9e25e2c73c187aad33dd190465b92d43c97b97 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/core.py @@ -0,0 +1,647 @@ +import numpy as np +import cv2 +import torch +import torchvision.transforms.functional as TF + +import sys as _sys +from keyword import iskeyword as _iskeyword +from operator import itemgetter as _itemgetter + +from segment_anything import SamPredictor + +from comfy import model_management + + +################################################################################ +### namedtuple +################################################################################ + +try: + from _collections import _tuplegetter +except ImportError: + _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc) + +def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessible by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Validate the field names. 
At the user's option, either generate an error + # message or automatically replace the field name with a valid name. + if isinstance(field_names, str): + field_names = field_names.replace(',', ' ').split() + field_names = list(map(str, field_names)) + typename = _sys.intern(str(typename)) + + if rename: + seen = set() + for index, name in enumerate(field_names): + if (not name.isidentifier() + or _iskeyword(name) + or name.startswith('_') + or name in seen): + field_names[index] = f'_{index}' + seen.add(name) + + for name in [typename] + field_names: + if type(name) is not str: + raise TypeError('Type names and field names must be strings') + if not name.isidentifier(): + raise ValueError('Type names and field names must be valid ' + f'identifiers: {name!r}') + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a ' + f'keyword: {name!r}') + + seen = set() + for name in field_names: + if name.startswith('_') and not rename: + raise ValueError('Field names cannot start with an underscore: ' + f'{name!r}') + if name in seen: + raise ValueError(f'Encountered duplicate field name: {name!r}') + seen.add(name) + + field_defaults = {} + if defaults is not None: + defaults = tuple(defaults) + if len(defaults) > len(field_names): + raise TypeError('Got more default values than field names') + field_defaults = dict(reversed(list(zip(reversed(field_names), + reversed(defaults))))) + + # Variables used in the methods and docstrings + field_names = tuple(map(_sys.intern, field_names)) + num_fields = len(field_names) + arg_list = ', '.join(field_names) + if num_fields == 1: + arg_list += ',' + repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' + tuple_new = tuple.__new__ + _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip + + # Create all the named tuple methods to be added to the class namespace + + namespace = { + '_tuple_new': tuple_new, + '__builtins__': {}, + '__name__': f'namedtuple_{typename}', + } + code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' + __new__ = eval(code, namespace) + __new__.__name__ = '__new__' + __new__.__doc__ = f'Create new instance of {typename}({arg_list})' + if defaults is not None: + __new__.__defaults__ = defaults + + @classmethod + def _make(cls, iterable): + result = tuple_new(cls, iterable) + if _len(result) != num_fields: + raise TypeError(f'Expected {num_fields} arguments, got {len(result)}') + return result + + _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' + 'or iterable') + + def _replace(self, /, **kwds): + result = self._make(_map(kwds.pop, field_names, self)) + if kwds: + raise ValueError(f'Got unexpected field names: {list(kwds)!r}') + return result + + _replace.__doc__ = (f'Return a new {typename} object replacing specified ' + 'fields with new values') + + def __repr__(self): + 'Return a nicely formatted representation string' + return self.__class__.__name__ + repr_fmt % self + + def _asdict(self): + 'Return a new dict which maps field names to their values.' + return _dict(_zip(self._fields, self)) + + def __getnewargs__(self): + 'Return self as a plain tuple. Used by copy and pickle.' 
+ return _tuple(self) + + # Modify function metadata to help with introspection and debugging + for method in ( + __new__, + _make.__func__, + _replace, + __repr__, + _asdict, + __getnewargs__, + ): + method.__qualname__ = f'{typename}.{method.__name__}' + + # Build-up the class namespace dictionary + # and use type() to build the result class + class_namespace = { + '__doc__': f'{typename}({arg_list})', + '__slots__': (), + '_fields': field_names, + '_field_defaults': field_defaults, + '__new__': __new__, + '_make': _make, + '_replace': _replace, + '__repr__': __repr__, + '_asdict': _asdict, + '__getnewargs__': __getnewargs__, + '__match_args__': field_names, + } + for index, name in enumerate(field_names): + doc = _sys.intern(f'Alias for field number {index}') + class_namespace[name] = _tuplegetter(index, doc) + + result = type(typename, (tuple,), class_namespace) + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython), or where the user has + # specified a particular module. + if module is None: + try: + module = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + if module is not None: + result.__module__ = module + + return result + + +SEG = namedtuple("SEG", + ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'], + defaults=[None]) + +def crop_ndarray4(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[:, y1:y2, x1:x2, :] + + return cropped + +crop_tensor4 = crop_ndarray4 + +def crop_ndarray2(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[y1:y2, x1:x2] + + return cropped + +def crop_image(image, crop_region): + return crop_tensor4(image, crop_region) + +def normalize_region(limit, startp, size): + if startp < 0: + new_endp = min(limit, size) + new_startp = 0 + elif startp + size > limit: + new_startp = max(0, limit - size) + new_endp = limit + else: + new_startp = startp + new_endp = min(limit, startp+size) + + return int(new_startp), int(new_endp) + +def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None): + x1 = bbox[0] + y1 = bbox[1] + x2 = bbox[2] + y2 = bbox[3] + + bbox_w = x2 - x1 + bbox_h = y2 - y1 + + crop_w = bbox_w * crop_factor + crop_h = bbox_h * crop_factor + + if crop_min_size is not None: + crop_w = max(crop_min_size, crop_w) + crop_h = max(crop_min_size, crop_h) + + kernel_x = x1 + bbox_w / 2 + kernel_y = y1 + bbox_h / 2 + + new_x1 = int(kernel_x - crop_w / 2) + new_y1 = int(kernel_y - crop_h / 2) + + # make sure position in (w,h) + new_x1, new_x2 = normalize_region(w, new_x1, crop_w) + new_y1, new_y2 = normalize_region(h, new_y1, crop_h) + + return [new_x1, new_y1, new_x2, new_y2] + +def create_segmasks(results): + bboxs = results[1] + segms = results[2] + confidence = results[3] + + results = [] + for i in range(len(segms)): + item = (bboxs[i], segms[i].astype(np.float32), confidence[i]) + results.append(item) + return results + +def dilate_masks(segmasks, dilation_factor, iter=1): + if dilation_factor == 0: + return segmasks + + dilated_masks = [] + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + kernel = cv2.UMat(kernel) + + for i in 
range(len(segmasks)): + cv2_mask = segmasks[i][1] + + cv2_mask = cv2.UMat(cv2_mask) + + if dilation_factor > 0: + dilated_mask = cv2.dilate(cv2_mask, kernel, iter) + else: + dilated_mask = cv2.erode(cv2_mask, kernel, iter) + + dilated_mask = dilated_mask.get() + + item = (segmasks[i][0], dilated_mask, segmasks[i][2]) + dilated_masks.append(item) + + return dilated_masks + +def is_same_device(a, b): + a_device = torch.device(a) if isinstance(a, str) else a + b_device = torch.device(b) if isinstance(b, str) else b + return a_device.type == b_device.type and a_device.index == b_device.index + +class SafeToGPU: + def __init__(self, size): + self.size = size + + def to_device(self, obj, device): + if is_same_device(device, 'cpu'): + obj.to(device) + else: + if is_same_device(obj.device, 'cpu'): # cpu to gpu + model_management.free_memory(self.size * 1.3, device) + if model_management.get_free_memory(device) > self.size * 1.3: + try: + obj.to(device) + except: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]") + else: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]") + +def center_of_bbox(bbox): + w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] + return bbox[0] + w/2, bbox[1] + h/2 + +def sam_predict(predictor, points, plabs, bbox, threshold): + point_coords = None if not points else np.array(points) + point_labels = None if not plabs else np.array(plabs) + + box = np.array([bbox]) if bbox is not None else None + + cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) + + total_masks = [] + + selected = False + max_score = 0 + max_mask = None + for idx in range(len(scores)): + if scores[idx] > max_score: + max_score = scores[idx] + max_mask = cur_masks[idx] + + if scores[idx] >= threshold: + selected = True + total_masks.append(cur_masks[idx]) + else: + pass + + if not selected and max_mask is not None: + total_masks.append(max_mask) + + return total_masks + +def make_2d_mask(mask): + if len(mask.shape) == 4: + return mask.squeeze(0).squeeze(0) + + elif len(mask.shape) == 3: + return mask.squeeze(0) + + return mask + +def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative): + mask = make_2d_mask(mask) + + points = [] + plabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(mask.shape[0] / 20)) + x_step = max(3, int(mask.shape[1] / 20)) + + for i in range(0, len(mask), y_step): + for j in range(0, len(mask[i]), x_step): + if mask[i][j] > threshold: + points.append((x + j, y + i)) + plabs.append(1) + elif use_negative and mask[i][j] == 0: + points.append((x + j, y + i)) + plabs.append(0) + + return points, plabs + +def gen_negative_hints(w, h, x1, y1, x2, y2): + npoints = [] + nplabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(w / 20)) + x_step = max(3, int(h / 20)) + + for i in range(10, h - 10, y_step): + for j in range(10, w - 10, x_step): + if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10): + npoints.append((j, i)) + nplabs.append(0) + + return npoints, nplabs + +def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative, + mask_hint_use_negative): + [x1, y1, x2, y2] = dilated_bbox + + points = [] + plabs = [] + if detection_hint == "center-1": + points.append(center) + plabs = [1] # 1 = foreground point, 0 = background point + + elif detection_hint == "horizontal-2": + gap = (x2 - x1) / 3 + points.append((x1 + gap, center[1])) + 
points.append((x1 + gap * 2, center[1])) + plabs = [1, 1] + + elif detection_hint == "vertical-2": + gap = (y2 - y1) / 3 + points.append((center[0], y1 + gap)) + points.append((center[0], y1 + gap * 2)) + plabs = [1, 1] + + elif detection_hint == "rect-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, center[1])) + points.append((x1 + x_gap * 2, center[1])) + points.append((center[0], y1 + y_gap)) + points.append((center[0], y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "diamond-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, y1 + y_gap)) + points.append((x1 + x_gap * 2, y1 + y_gap)) + points.append((x1 + x_gap, y1 + y_gap * 2)) + points.append((x1 + x_gap * 2, y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "mask-point-bbox": + center = center_of_bbox(seg.bbox) + points.append(center) + plabs = [1] + + elif detection_hint == "mask-area": + points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1], + seg.cropped_mask, + mask_hint_threshold, use_small_negative) + + if mask_hint_use_negative == "Outter": + npoints, nplabs = gen_negative_hints(image.shape[0], image.shape[1], + seg.crop_region[0], seg.crop_region[1], + seg.crop_region[2], seg.crop_region[3]) + + points += npoints + plabs += nplabs + + return points, plabs + +def combine_masks2(masks): + if len(masks) == 0: + return None + else: + initial_cv2_mask = np.array(masks[0]).astype(np.uint8) + combined_cv2_mask = initial_cv2_mask + + for i in range(1, len(masks)): + cv2_mask = np.array(masks[i]).astype(np.uint8) + + if combined_cv2_mask.shape == cv2_mask.shape: + combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask) + else: + # do nothing - incompatible mask + pass + + mask = torch.from_numpy(combined_cv2_mask) + return mask + +def dilate_mask(mask, dilation_factor, iter=1): + if dilation_factor == 0: + return make_2d_mask(mask) + + mask = make_2d_mask(mask) + + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + mask = cv2.UMat(mask) + kernel = cv2.UMat(kernel) + + if dilation_factor > 0: + result = cv2.dilate(mask, kernel, iter) + else: + result = cv2.erode(mask, kernel, iter) + + return result.get() + +def convert_and_stack_masks(masks): + if len(masks) == 0: + return None + + mask_tensors = [] + for mask in masks: + mask_array = np.array(mask, dtype=np.uint8) + mask_tensor = torch.from_numpy(mask_array) + mask_tensors.append(mask_tensor) + + stacked_masks = torch.stack(mask_tensors, dim=0) + stacked_masks = stacked_masks.unsqueeze(1) + + return stacked_masks + +def merge_and_stack_masks(stacked_masks, group_size): + if stacked_masks is None: + return None + + num_masks = stacked_masks.size(0) + merged_masks = [] + + for i in range(0, num_masks, group_size): + subset_masks = stacked_masks[i:i + group_size] + merged_mask = torch.any(subset_masks, dim=0) + merged_masks.append(merged_mask) + + if len(merged_masks) > 0: + merged_masks = torch.stack(merged_masks, dim=0) + + return merged_masks + +def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation, + threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative): + if sam_model.is_auto_mode: + device = model_management.get_torch_device() + sam_model.safe_to.to_device(sam_model, device=device) + + try: + predictor = SamPredictor(sam_model) + image = np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + predictor.set_image(image, "RGB") + + total_masks = [] + + use_small_negative = mask_hint_use_negative == "Small" + + # seg_shape = segs[0] + segs = segs[1] + if detection_hint == "mask-points": + points = [] + plabs = [] + + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + points.append(center) + + # small point is background, big point is foreground + if use_small_negative and bbox[2] - bbox[0] < 10: + plabs.append(0) + else: + plabs.append(1) + + detected_masks = sam_predict(predictor, points, plabs, None, threshold) + total_masks += detected_masks + + else: + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + x1 = max(bbox[0] - bbox_expansion, 0) + y1 = max(bbox[1] - bbox_expansion, 0) + x2 = min(bbox[2] + bbox_expansion, image.shape[1]) + y2 = min(bbox[3] + bbox_expansion, image.shape[0]) + + dilated_bbox = [x1, y1, x2, y2] + + points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox, + mask_hint_threshold, use_small_negative, + mask_hint_use_negative) + + detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold) + + total_masks += detected_masks + + # merge every collected masks + mask = combine_masks2(total_masks) + + finally: + if sam_model.is_auto_mode: + sam_model.cpu() + + pass + + mask_working_device = torch.device("cpu") + + if mask is not None: + mask = mask.float() + mask = dilate_mask(mask.cpu().numpy(), dilation) + mask = torch.from_numpy(mask) + mask = mask.to(device=mask_working_device) + else: + # Extracting batch, height and width + height, width, _ = image.shape + mask = torch.zeros( + (height, width), dtype=torch.float32, device=mask_working_device + ) # empty mask + + stacked_masks = convert_and_stack_masks(total_masks) + + return (mask, merge_and_stack_masks(stacked_masks, group_size=3)) + +def tensor2mask(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t + if size[3] == 1: + return t[:,:,:,0] + elif size[3] == 4: + # Not sure what the right thing to do here is. 
Going to try to be a little smart and use alpha unless all alpha is 1 in case we'll fallback to RGB behavior + if torch.min(t[:, :, :, 3]).item() != 1.: + return t[:,:,:,3] + return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:] + +def tensor2rgb(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 3) + if size[3] == 1: + return t.repeat(1, 1, 1, 3) + elif size[3] == 4: + return t[:, :, :, :3] + else: + return t + +def tensor2rgba(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 4) + elif size[3] == 1: + return t.repeat(1, 1, 1, 4) + elif size[3] == 3: + alpha_tensor = torch.ones((size[0], size[1], size[2], 1)) + return torch.cat((t, alpha_tensor), dim=3) + else: + return t diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_masking/segs.py b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/segs.py new file mode 100644 index 0000000000000000000000000000000000000000..cd84054dc34c50189aebc75a7661587aae01617a --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/segs.py @@ -0,0 +1,22 @@ +def filter(segs, labels): + labels = set([label.strip() for label in labels]) + + if 'all' in labels: + return (segs, (segs[0], []), ) + else: + res_segs = [] + remained_segs = [] + + for x in segs[1]: + if x.label in labels: + res_segs.append(x) + elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']: + res_segs.append(x) + elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']: + res_segs.append(x) + elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']: + res_segs.append(x) + else: + remained_segs.append(x) + + return ((segs[0], res_segs), (segs[0], remained_segs), ) diff --git a/custom_nodes/ComfyUI-ReActor/scripts/r_masking/subcore.py b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/subcore.py new file mode 100644 index 0000000000000000000000000000000000000000..bc47c8591c6496efaa0018eb9749bc198b366557 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/r_masking/subcore.py @@ -0,0 +1,117 @@ +import numpy as np +import cv2 +from PIL import Image + +import scripts.r_masking.core as core +from reactor_utils import tensor_to_pil + +try: + from ultralytics import YOLO +except Exception as e: + print(e) + + +def load_yolo(model_path: str): + try: + return YOLO(model_path) + except ModuleNotFoundError: + # https://github.com/ultralytics/ultralytics/issues/3856 + YOLO("yolov8n.pt") + return YOLO(model_path) + +def inference_bbox( + model, + image: Image.Image, + confidence: float = 0.3, + device: str = "", +): + pred = model(image, conf=confidence, device=device) + + bboxes = pred[0].boxes.xyxy.cpu().numpy() + cv2_image = np.array(image) + if len(cv2_image.shape) == 3: + cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing + else: + # Handle the grayscale image here + # For example, you might want to convert it to a 3-channel grayscale image for consistency: + cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR) + cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY) + + segms = [] + for x0, y0, x1, y1 in bboxes: + cv2_mask = np.zeros(cv2_gray.shape, np.uint8) + cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1) + cv2_mask_bool = cv2_mask.astype(bool) + segms.append(cv2_mask_bool) + + n, m = bboxes.shape + if n == 0: + return [[], [], [], []] + + results = [[], [], [], []] + for i in range(len(bboxes)): + 
results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())]) + results[1].append(bboxes[i]) + results[2].append(segms[i]) + results[3].append(pred[0].boxes[i].conf.cpu().numpy()) + + return results + + +class UltraBBoxDetector: + bbox_model = None + + def __init__(self, bbox_model): + self.bbox_model = bbox_model + + def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): + drop_size = max(drop_size, 1) + detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + items = [] + h = image.shape[1] + w = image.shape[2] + + for x, label in zip(segmasks, detected_results[0]): + item_bbox = x[0] + item_mask = x[1] + + y1, x1, y2, x2 = item_bbox + + if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue + crop_region = core.make_crop_region(w, h, item_bbox, crop_factor) + + if detailer_hook is not None: + crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region) + + cropped_image = core.crop_image(image, crop_region) + cropped_mask = core.crop_ndarray2(item_mask, crop_region) + confidence = x[2] + # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h) + + item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None) + + items.append(item) + + shape = image.shape[1], image.shape[2] + segs = shape, items + + if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): + segs = detailer_hook.post_detection(segs) + + return segs + + def detect_combined(self, image, threshold, dilation): + detected_results = inference_bbox(self.bbox_model, core.tensor2pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + return core.combine_masks(segmasks) + + def setAux(self, x): + pass diff --git a/custom_nodes/ComfyUI-ReActor/scripts/reactor_faceswap.py b/custom_nodes/ComfyUI-ReActor/scripts/reactor_faceswap.py new file mode 100644 index 0000000000000000000000000000000000000000..2c407dcc776ad020c84a8b15f020503f74626d2f --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/reactor_faceswap.py @@ -0,0 +1,193 @@ +import os, glob + +from PIL import Image + +import modules.scripts as scripts +# from modules.upscaler import Upscaler, UpscalerData +from modules import scripts, scripts_postprocessing +from modules.processing import ( + StableDiffusionProcessing, + StableDiffusionProcessingImg2Img, +) +from modules.shared import state +from scripts.reactor_logger import logger +from scripts.reactor_swapper import ( + swap_face, + swap_face_many, + get_current_faces_model, + analyze_faces, + half_det_size, + providers +) +import folder_paths +import comfy.model_management as model_management + + +def get_models(): + swappers = [ + "insightface", + "reswapper" + ] + models_list = [] + for folder in swappers: + models_folder = folder + "/*" + models_path = os.path.join(folder_paths.models_dir,models_folder) + models = glob.glob(models_path) + models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")] + models_list.extend(models) + return models_list + + +class FaceSwapScript(scripts.Script): + + def process( + self, + p: StableDiffusionProcessing, + img, + enable, + source_faces_index, + faces_index, + model, + swap_in_source, + swap_in_generated, + gender_source, + gender_target, + face_model, 
+ faces_order, + face_boost_enabled, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation, + ): + self.enable = enable + if self.enable: + + self.source = img + self.swap_in_generated = swap_in_generated + self.gender_source = gender_source + self.gender_target = gender_target + self.model = model + self.face_model = face_model + self.faces_order = faces_order + self.face_boost_enabled = face_boost_enabled + self.face_restore_model = face_restore_model + self.face_restore_visibility = face_restore_visibility + self.codeformer_weight = codeformer_weight + self.interpolation = interpolation + self.source_faces_index = [ + int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric() + ] + self.faces_index = [ + int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() + ] + if len(self.source_faces_index) == 0: + self.source_faces_index = [0] + if len(self.faces_index) == 0: + self.faces_index = [0] + + if self.gender_source is None or self.gender_source == "no": + self.gender_source = 0 + elif self.gender_source == "female": + self.gender_source = 1 + elif self.gender_source == "male": + self.gender_source = 2 + + if self.gender_target is None or self.gender_target == "no": + self.gender_target = 0 + elif self.gender_target == "female": + self.gender_target = 1 + elif self.gender_target == "male": + self.gender_target = 2 + + # if self.source is not None: + if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source: + logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index) + + if len(p.init_images) == 1: + + result = swap_face( + self.source, + p.init_images[0], + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images[0] = result + + # for i in range(len(p.init_images)): + # if state.interrupted or model_management.processing_interrupted(): + # logger.status("Interrupted by User") + # break + # if len(p.init_images) > 1: + # logger.status(f"Swap in %s", i) + # result = swap_face( + # self.source, + # p.init_images[i], + # source_faces_index=self.source_faces_index, + # faces_index=self.faces_index, + # model=self.model, + # gender_source=self.gender_source, + # gender_target=self.gender_target, + # face_model=self.face_model, + # ) + # p.init_images[i] = result + + elif len(p.init_images) > 1: + result = swap_face_many( + self.source, + p.init_images, + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images = result + + logger.status("--Done!--") + # else: + # logger.error(f"Please provide a source face") + + def postprocess_batch(self, p, *args, **kwargs): + if self.enable: + images = kwargs["images"] + + def postprocess_image(self, p, 
script_pp: scripts.PostprocessImageArgs, *args):
+        if self.enable and self.swap_in_generated:
+            if self.source is not None:
+                logger.status("Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
+                image: Image.Image = script_pp.image
+                result = swap_face(
+                    self.source,
+                    image,
+                    source_faces_index=self.source_faces_index,
+                    faces_index=self.faces_index,
+                    model=self.model,
+                    gender_source=self.gender_source,
+                    gender_target=self.gender_target,
+                )
+                try:
+                    pp = scripts_postprocessing.PostprocessedImage(result)
+                    pp.info = {}
+                    p.extra_generation_params.update(pp.info)
+                    script_pp.image = pp.image
+                except Exception:
+                    logger.error("Cannot create a result image")
diff --git a/custom_nodes/ComfyUI-ReActor/scripts/reactor_logger.py b/custom_nodes/ComfyUI-ReActor/scripts/reactor_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6876489dd9f810e516644caed5f7c672417ffbd
--- /dev/null
+++ b/custom_nodes/ComfyUI-ReActor/scripts/reactor_logger.py
@@ -0,0 +1,47 @@
+import logging
+import copy
+import sys
+
+from modules import shared
+from reactor_utils import addLoggingLevel
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "STATUS": "\033[38;5;173m",  # Calm ORANGE
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+# Create a new logger
+logger = logging.getLogger("ReActor")
+logger.propagate = False
+
+# Add Custom Level
+# logging.addLevelName(logging.INFO, "STATUS")
+addLoggingLevel("STATUS", logging.INFO + 5)
+
+# Add handler if we don't have one.
+if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s",datefmt="%H:%M:%S") + ) + logger.addHandler(handler) + +# Configure logger +loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO") +loglevel = getattr(logging, loglevel_string.upper(), "info") +logger.setLevel(loglevel) diff --git a/custom_nodes/ComfyUI-ReActor/scripts/reactor_sfw.py b/custom_nodes/ComfyUI-ReActor/scripts/reactor_sfw.py new file mode 100644 index 0000000000000000000000000000000000000000..7c730c16dbb22d8a5a739eba046ed434cc96c743 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/reactor_sfw.py @@ -0,0 +1,34 @@ +from transformers import pipeline +from PIL import Image +import logging +import os +from reactor_utils import download +from scripts.reactor_logger import logger + +def ensure_nsfw_model(nsfwdet_model_path): + """Download NSFW detection model if it doesn't exist""" + if not os.path.exists(nsfwdet_model_path): + os.makedirs(nsfwdet_model_path) + nd_urls = [ + "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/config.json", + "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/model.safetensors", + "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/preprocessor_config.json", + ] + for model_url in nd_urls: + model_name = os.path.basename(model_url) + model_path = os.path.join(nsfwdet_model_path, model_name) + download(model_url, model_path, model_name) + +SCORE = 0.96 + +logging.getLogger("transformers").setLevel(logging.ERROR) + +def nsfw_image(img_path: str, model_path: str): + ensure_nsfw_model(model_path) + with Image.open(img_path) as img: + predict = pipeline("image-classification", model=model_path) + result = predict(img) + if result[0]["label"] == "nsfw" and result[0]["score"] > SCORE: + logger.status(f"NSFW content detected, skipping...") + return True + return False diff --git a/custom_nodes/ComfyUI-ReActor/scripts/reactor_swapper.py b/custom_nodes/ComfyUI-ReActor/scripts/reactor_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa9c7df30c6d40f421088ace36d3380badf2267 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/reactor_swapper.py @@ -0,0 +1,576 @@ +import os +import shutil +from typing import List, Union + +import cv2 +import numpy as np +from PIL import Image + +import insightface +from insightface.app.common import Face +# try: +# import torch.cuda as cuda +# except: +# cuda = None +import torch + +import folder_paths +import comfy.model_management as model_management +from modules.shared import state + +from scripts.reactor_logger import logger +from reactor_utils import ( + move_path, + get_image_md5hash, +) +from scripts.r_faceboost import swapper, restorer + +import warnings + +np.warnings = warnings +np.warnings.filterwarnings('ignore') + +# PROVIDERS +try: + if torch.cuda.is_available(): + providers = ["CUDAExecutionProvider"] + elif torch.backends.mps.is_available(): + providers = ["CoreMLExecutionProvider"] + elif hasattr(torch,'dml') or hasattr(torch,'privateuseone'): + providers = ["ROCMExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] +except Exception as e: + logger.debug(f"ExecutionProviderError: {e}.\nEP is set to CPU.") + providers = ["CPUExecutionProvider"] +# if cuda is not None: +# if cuda.is_available(): +# providers = ["CUDAExecutionProvider"] +# else: +# providers = ["CPUExecutionProvider"] +# else: +# providers = 
["CPUExecutionProvider"] + +models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") +insightface_path_old = os.path.join(models_path_old, "insightface") +insightface_models_path_old = os.path.join(insightface_path_old, "models") + +models_path = folder_paths.models_dir +insightface_path = os.path.join(models_path, "insightface") +insightface_models_path = os.path.join(insightface_path, "models") +reswapper_path = os.path.join(models_path, "reswapper") + +if os.path.exists(models_path_old): + move_path(insightface_models_path_old, insightface_models_path) + move_path(insightface_path_old, insightface_path) + move_path(models_path_old, models_path) +if os.path.exists(insightface_path) and os.path.exists(insightface_path_old): + shutil.rmtree(insightface_path_old) + shutil.rmtree(models_path_old) + + +FS_MODEL = None +CURRENT_FS_MODEL_PATH = None + +ANALYSIS_MODELS = { + "640": None, + "320": None, +} + +SOURCE_FACES = None +SOURCE_IMAGE_HASH = None +TARGET_FACES = None +TARGET_IMAGE_HASH = None +TARGET_FACES_LIST = [] +TARGET_IMAGE_LIST_HASH = [] + +def unload_model(model): + if model is not None: + # check if model has unload method + # if "unload" in model: + # model.unload() + # if "model_unload" in model: + # model.model_unload() + del model + return None + +def unload_all_models(): + global FS_MODEL, CURRENT_FS_MODEL_PATH + FS_MODEL = unload_model(FS_MODEL) + ANALYSIS_MODELS["320"] = unload_model(ANALYSIS_MODELS["320"]) + ANALYSIS_MODELS["640"] = unload_model(ANALYSIS_MODELS["640"]) + +def get_current_faces_model(): + global SOURCE_FACES + return SOURCE_FACES + +def getAnalysisModel(det_size = (640, 640)): + global ANALYSIS_MODELS + ANALYSIS_MODEL = ANALYSIS_MODELS[str(det_size[0])] + if ANALYSIS_MODEL is None: + ANALYSIS_MODEL = insightface.app.FaceAnalysis( + name="buffalo_l", providers=providers, root=insightface_path + ) + ANALYSIS_MODEL.prepare(ctx_id=0, det_size=det_size) + ANALYSIS_MODELS[str(det_size[0])] = ANALYSIS_MODEL + return ANALYSIS_MODEL + +def getFaceSwapModel(model_path: str): + global FS_MODEL, CURRENT_FS_MODEL_PATH + if FS_MODEL is None or CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path: + CURRENT_FS_MODEL_PATH = model_path + FS_MODEL = unload_model(FS_MODEL) + FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers) + + return FS_MODEL + + +def sort_by_order(face, order: str): + if order == "left-right": + return sorted(face, key=lambda x: x.bbox[0]) + if order == "right-left": + return sorted(face, key=lambda x: x.bbox[0], reverse = True) + if order == "top-bottom": + return sorted(face, key=lambda x: x.bbox[1]) + if order == "bottom-top": + return sorted(face, key=lambda x: x.bbox[1], reverse = True) + if order == "small-large": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) + # if order == "large-small": + # return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + # by default "large-small": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + +def get_face_gender( + face, + face_index, + gender_condition, + operated: str, + order: str, +): + gender = [ + x.sex + for x in face + ] + gender.reverse() + # If index is outside of bounds, return None, avoid exception + if face_index >= len(gender): + logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender)) + return None, 0 + face_gender = gender[face_index] + 
logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender) + if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"): + logger.status("OK - Detected Gender matches Condition") + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + else: + logger.status("WRONG - Detected Gender doesn't match Condition") + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 1 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 1 + +def half_det_size(det_size): + logger.status("Trying to halve 'det_size' parameter") + return (det_size[0] // 2, det_size[1] // 2) + +def analyze_faces(img_data: np.ndarray, det_size=(640, 640)): + face_analyser = getAnalysisModel(det_size) + faces = face_analyser.get(img_data) + + # Try halving det_size if no faces are found + if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return analyze_faces(img_data, det_size_half) + + return faces + +def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0, order="large-small"): + + buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip") + if os.path.exists(buffalo_path): + os.remove(buffalo_path) + + if gender_source != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_source,"Source", order) + + if gender_target != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_target,"Target", order) + + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + + +def swap_face( + source_img: Union[Image.Image, None], + target_img: Image.Image, + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH + result_image = target_img + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = 
base64.b64decode(base64_data) + else: + # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + SOURCE_IMAGE_HASH = source_image_md5hash + source_image_same = False + else: + source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? %s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_image_md5hash = get_image_md5hash(target_img) + + if TARGET_IMAGE_HASH is None: + TARGET_IMAGE_HASH = target_image_md5hash + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_HASH == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_HASH = target_image_md5hash + + logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH) + logger.info("Target Image the Same? %s", target_image_same) + + if TARGET_FACES is None or not target_image_same: + logger.status("Analyzing Target Image...") + target_faces = analyze_faces(target_img) + TARGET_FACES = target_faces + elif target_image_same: + logger.status("Using Hashed Target Face(s) Model...") + target_faces = TARGET_FACES + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_image + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif source_face is not None: + result = target_img + if "inswapper" in model: + model_path = os.path.join(insightface_path, model) + elif "reswapper" in model: + model_path = os.path.join(reswapper_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + logger.status("Checked all existing target faces, skipping swapping...") + 
break + + if len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face is not None and wrong_gender == 0: + logger.status(f"Swapping...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = face_swapper.get(result, target_face, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(result, target_face, source_face) + elif wrong_gender == 1: + wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_image + +def swap_face_many( + source_img: Union[Image.Image, None], + target_imgs: List[Image.Image], + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, TARGET_FACES_LIST, TARGET_IMAGE_LIST_HASH + result_images = target_imgs + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = base64.b64decode(base64_data) + else: + # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_imgs = [cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) for target_img in target_imgs] + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + 
SOURCE_IMAGE_HASH = source_image_md5hash + source_image_same = False + else: + source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? %s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_faces = [] + for i, target_img in enumerate(target_imgs): + if state.interrupted or model_management.processing_interrupted(): + logger.status("Interrupted by User") + break + + target_image_md5hash = get_image_md5hash(target_img) + if len(TARGET_IMAGE_LIST_HASH) == 0: + TARGET_IMAGE_LIST_HASH = [target_image_md5hash] + target_image_same = False + elif len(TARGET_IMAGE_LIST_HASH) == i: + TARGET_IMAGE_LIST_HASH.append(target_image_md5hash) + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_LIST_HASH[i] == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_LIST_HASH[i] = target_image_md5hash + + logger.info("(Image %s) Target Image MD5 Hash = %s", i, TARGET_IMAGE_LIST_HASH[i]) + logger.info("(Image %s) Target Image the Same? %s", i, target_image_same) + + if len(TARGET_FACES_LIST) == 0: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST = [target_face] + elif len(TARGET_FACES_LIST) == i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST.append(target_face) + elif len(TARGET_FACES_LIST) != i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST[i] = target_face + elif target_image_same: + logger.status("(Image %s) Using Hashed Target Face(s) Model...", i) + target_face = TARGET_FACES_LIST[i] + + + # logger.status(f"Analyzing Target Image {i}...") + # target_face = analyze_faces(target_img) + if target_face is not None: + target_faces.append(target_face) + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_images + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif 
source_face is not None: + results = target_imgs + model_path = model_path = os.path.join(insightface_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + logger.status("Checked all existing target faces, skipping swapping...") + break + + if len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + # Reading results to make current face swap on a previous face result + for i, (target_img, target_face) in enumerate(zip(results, target_faces)): + target_face_single, wrong_gender = get_face_single(target_img, target_face, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face_single is not None and wrong_gender == 0: + result = target_img + logger.status(f"Swapping {i}...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = face_swapper.get(target_img, target_face_single, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(target_img, target_face_single, source_face) + results[i] = result + elif wrong_gender == 1: + wrong_gender = 0 + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_images = [Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) for result in results] + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_images diff --git a/custom_nodes/ComfyUI-ReActor/scripts/reactor_version.py b/custom_nodes/ComfyUI-ReActor/scripts/reactor_version.py new file mode 100644 index 0000000000000000000000000000000000000000..df097c60e3005b25147fc0573631952fb144e164 --- /dev/null +++ b/custom_nodes/ComfyUI-ReActor/scripts/reactor_version.py @@ -0,0 +1,13 @@ +app_title = "ReActor Node for ComfyUI" +version_flag = "v0.6.0-a1" + +COLORS = { + "CYAN": "\033[0;36m", # CYAN + "ORANGE": "\033[38;5;173m", # Calm ORANGE + "GREEN": "\033[0;32m", # GREEN + "YELLOW": "\033[0;33m", # YELLOW + "RED": "\033[0;91m", # RED + "0": "\033[0m", # RESET COLOR +} + +print(f"\n{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}") diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example new file mode 100644 index 0000000000000000000000000000000000000000..29ab2aa72319354b147b7dd79e1c3179e54d3d06 --- /dev/null +++ b/custom_nodes/example_node.py.example @@ -0,0 +1,155 @@ +class Example: + """ + A example node + + Class methods + ------------- + INPUT_TYPES (dict): + Tell the main program input parameters of nodes. + IS_CHANGED: + optional method to control when the node is re executed. 
+ + Attributes + ---------- + RETURN_TYPES (`tuple`): + The type of each element in the output tuple. + RETURN_NAMES (`tuple`): + Optional: The name of each output in the output tuple. + FUNCTION (`str`): + The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() + OUTPUT_NODE ([`bool`]): + If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. + The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. + Assumed to be False if not present. + CATEGORY (`str`): + The category the node should appear in the UI. + DEPRECATED (`bool`): + Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain + functional in existing workflows that use them. + EXPERIMENTAL (`bool`): + Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to + significant changes or removal in future versions. Use with caution in production workflows. + execute(s) -> tuple || None: + The entry point method. The name of this method must be the same as the value of property `FUNCTION`. + For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. + """ + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + """ + Return a dictionary which contains config for all input fields. + Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". + Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. + The type can be a list for selection. + + Returns: `dict`: + - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` + - Value input_fields (`dict`): Contains input fields config: + * Key field_name (`string`): Name of a entry-point method's argument + * Value field_config (`tuple`): + + First value is a string indicate the type of field or a list for selection. + + Second value is a config for type "INT", "STRING" or "FLOAT". + """ + return { + "required": { + "image": ("IMAGE",), + "int_field": ("INT", { + "default": 0, + "min": 0, #Minimum value + "max": 4096, #Maximum value + "step": 64, #Slider's step + "display": "number", # Cosmetic only: display as "number" or "slider" + "lazy": True # Will only be evaluated if check_lazy_status requires it + }), + "float_field": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 10.0, + "step": 0.01, + "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + "display": "number", + "lazy": True + }), + "print_to_screen": (["enable", "disable"],), + "string_field": ("STRING", { + "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node + "default": "Hello World!", + "lazy": True + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + #RETURN_NAMES = ("image_output_name",) + + FUNCTION = "test" + + #OUTPUT_NODE = False + + CATEGORY = "Example" + + def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + """ + Return a list of input names that need to be evaluated. + + This function will be called if there are any lazy inputs which have not yet been + evaluated. 
As long as you return at least one field which has not yet been evaluated
+            (and more exist), this function will be called again once the value of the requested
+            field is available.
+
+            Any evaluated inputs will be passed as arguments to this function. Any unevaluated
+            inputs will have the value None.
+        """
+        if print_to_screen == "enable":
+            return ["int_field", "float_field", "string_field"]
+        else:
+            return []
+
+    def test(self, image, string_field, int_field, float_field, print_to_screen):
+        if print_to_screen == "enable":
+            print(f"""Your input contains:
+                string_field aka input text: {string_field}
+                int_field: {int_field}
+                float_field: {float_field}
+            """)
+        # Do some processing on the image; in this example we just invert it.
+        image = 1.0 - image
+        return (image,)
+
+    """
+        The node will always be re-executed if any of the inputs change, but
+        this method can be used to force the node to execute again even when the inputs don't change.
+        You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
+        executed; if it is different, the node will be executed again.
+        This method is used in the core repo for the LoadImage node, where the image hash is returned as a string; if the image hash
+        changes between executions, the LoadImage node is executed again.
+    """
+    #@classmethod
+    #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
+    #    return ""
+
+# Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension
+# WEB_DIRECTORY = "./somejs"
+
+
+# Add custom API routes, using router
+from aiohttp import web
+from server import PromptServer
+
+@PromptServer.instance.routes.get("/hello")
+async def get_hello(request):
+    return web.json_response("hello")
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+    "Example": Example
+}
+
+# A dictionary that contains the friendly/human-readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "Example": "Example Node"
+}
diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py
new file mode 100644
index 0000000000000000000000000000000000000000..15f87f9f56175f33df18c6142f9e13c4503b1186
--- /dev/null
+++ b/custom_nodes/websocket_image_save.py
@@ -0,0 +1,44 @@
+from PIL import Image
+import numpy as np
+import comfy.utils
+import time
+
+# You can use this node to save full-size images through the websocket. The
+# images will be sent in exactly the same format as the image previews: as
+# binary images on the websocket with an 8-byte header indicating the type
+# of binary message (first 4 bytes) and the image format (next 4 bytes).
+
+# Note that no metadata will be put in the images saved with this node.
+
+class SaveImageWebsocket:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"images": ("IMAGE", ),}
+                }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_images"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "api/image"
+
+    def save_images(self, images):
+        pbar = comfy.utils.ProgressBar(images.shape[0])
+        step = 0
+        for image in images:
+            i = 255. * image.cpu().numpy()
+            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+            pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
+            step += 1
+
+        return {}
+
+    @classmethod
+    def IS_CHANGED(s, images):
+        return time.time()
+
+NODE_CLASS_MAPPINGS = {
+    "SaveImageWebsocket": SaveImageWebsocket,
+}
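Note on consuming SaveImageWebsocket output: a client only needs to strip the 8-byte header described in the comments above before using the image bytes. The sketch below is illustrative and not part of the diff; it assumes the third-party websocket-client package, a ComfyUI server reachable at 127.0.0.1:8188, and big-endian 32-bit header fields (message type, then image format). Adjust those details to your setup.

import struct

import websocket  # pip install websocket-client  (assumed dependency)


def receive_binary_images(client_id: str, max_frames: int = 1):
    """Collect binary image frames sent over the ComfyUI websocket."""
    ws = websocket.WebSocket()
    ws.connect(f"ws://127.0.0.1:8188/ws?clientId={client_id}")  # assumed server address
    frames = []
    while len(frames) < max_frames:
        message = ws.recv()
        if isinstance(message, bytes):
            # First 4 bytes: binary message type; next 4 bytes: image format (per the node's comments).
            msg_type, img_format = struct.unpack(">II", message[:8])
            frames.append({"type": msg_type, "format": img_format, "data": message[8:]})
        # Text frames carry JSON status updates and are ignored in this sketch.
    ws.close()
    return frames

Writing frames[0]["data"] straight to a file should be enough in the common case, since the node above requests PNG encoding and no metadata is embedded.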