Upload folder using huggingface_hub
- LICENSE +201 -0
- app.py +821 -0
- assets/mask2face/black_seg.png +0 -0
- assets/mask2face/handou_seg.png +0 -0
- assets/multimodal/liuyifei_seg.png +0 -0
- assets/multimodal/musk_seg.png +0 -0
- config/Face-MoGLE.yaml +40 -0
- requirements.txt +17 -0
- src/flux/__pycache__/block.cpython-311.pyc +0 -0
- src/flux/__pycache__/condition.cpython-311.pyc +0 -0
- src/flux/__pycache__/generate.cpython-311.pyc +0 -0
- src/flux/__pycache__/lora_controller.cpython-311.pyc +0 -0
- src/flux/__pycache__/pipeline_tools.cpython-311.pyc +0 -0
- src/flux/__pycache__/transformer.cpython-311.pyc +0 -0
- src/flux/block.py +345 -0
- src/flux/condition.py +129 -0
- src/flux/generate.py +316 -0
- src/flux/lora_controller.py +75 -0
- src/flux/pipeline_tools.py +52 -0
- src/flux/transformer.py +257 -0
- src/moe/__pycache__/mogle.cpython-311.pyc +0 -0
- src/moe/mogle.py +140 -0
- src/train/callbacks.py +170 -0
- src/train/data.py +98 -0
- src/train/model.py +201 -0
- weights/mogle.pt +3 -0
- weights/pytorch_lora_weights.safetensors +3 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2024] [Zhenxiong Tan]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
app.py
ADDED
@@ -0,0 +1,821 @@
import gradio as gr
from PIL import Image
import torch
import yaml
import numpy as np
from torchvision.models import convnext_base, convnext_small
from torch import nn as nn
import facer
from torch import Tensor
import math
from typing import Any, Optional, Tuple, Type
from torch.nn import functional as F
import torchvision
from torchvision import transforms as T
from src.flux.generate import generate
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition
from src.moe.mogle import MoGLE


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x


class FaceDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int = 256,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.background_token = nn.Embedding(1, transformer_dim)
        self.neck_token = nn.Embedding(1, transformer_dim)
        self.face_token = nn.Embedding(1, transformer_dim)
        self.cloth_token = nn.Embedding(1, transformer_dim)
        self.rightear_token = nn.Embedding(1, transformer_dim)
        self.leftear_token = nn.Embedding(1, transformer_dim)
        self.rightbro_token = nn.Embedding(1, transformer_dim)
        self.leftbro_token = nn.Embedding(1, transformer_dim)
        self.righteye_token = nn.Embedding(1, transformer_dim)
        self.lefteye_token = nn.Embedding(1, transformer_dim)
        self.nose_token = nn.Embedding(1, transformer_dim)
        self.innermouth_token = nn.Embedding(1, transformer_dim)
        self.lowerlip_token = nn.Embedding(1, transformer_dim)
        self.upperlip_token = nn.Embedding(1, transformer_dim)
        self.hair_token = nn.Embedding(1, transformer_dim)
        self.glass_token = nn.Embedding(1, transformer_dim)
        self.hat_token = nn.Embedding(1, transformer_dim)
        self.earring_token = nn.Embedding(1, transformer_dim)
        self.necklace_token = nn.Embedding(1, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )

        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        image_embeddings - torch.Size([1, 256, 128, 128])
        image_pe - torch.Size([1, 256, 128, 128])
        """
        output_tokens = torch.cat(
            [
                self.background_token.weight,
                self.neck_token.weight,
                self.face_token.weight,
                self.cloth_token.weight,
                self.rightear_token.weight,
                self.leftear_token.weight,
                self.rightbro_token.weight,
                self.leftbro_token.weight,
                self.righteye_token.weight,
                self.lefteye_token.weight,
                self.nose_token.weight,
                self.innermouth_token.weight,
                self.lowerlip_token.weight,
                self.upperlip_token.weight,
                self.hair_token.weight,
                self.glass_token.weight,
                self.hat_token.weight,
                self.earring_token.weight,
                self.necklace_token.weight,
            ],
            dim=0,
        )

        tokens = output_tokens.unsqueeze(0).expand(
            image_embeddings.size(0), -1, -1
        )  ##### torch.Size([4, 11, 256])

        src = image_embeddings  ##### torch.Size([4, 256, 128, 128])
        pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(
            src, pos_src, tokens
        )  ####### hs - torch.Size([BS, 11, 256]), src - torch.Size([BS, 16348, 256])
        mask_token_out = hs[:, :, :]

        src = src.transpose(1, 2).view(b, c, h, w)  ##### torch.Size([4, 256, 128, 128])
        upscaled_embedding = self.output_upscaling(
            src
        )  ##### torch.Size([4, 32, 512, 512])
        hyper_in = self.output_hypernetwork_mlps(
            mask_token_out
        )  ##### torch.Size([1, 11, 32])
        b, c, h, w = upscaled_embedding.shape
        seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
            b, -1, h, w
        )  ##### torch.Size([1, 11, 512, 512])

        return seg_output


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out


class SegfaceMLP(nn.Module):
    """
    Linear Embedding.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class SegFaceCeleb(nn.Module):
    def __init__(self, input_resolution, model):
        super(SegFaceCeleb, self).__init__()
        self.input_resolution = input_resolution
        self.model = model

        if self.model == "convnext_base":
            convnext = convnext_base(pretrained=False)
            self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
            self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
            self.multi_scale_features = []

        if self.model == "convnext_small":
            convnext = convnext_small(pretrained=False)
            self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
            self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
            self.multi_scale_features = []

        if self.model == "convnext_tiny":
            convnext = convnext_small(pretrained=False)
            self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
            self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
            self.multi_scale_features = []

        embed_dim = 1024
        out_chans = 256

        self.pe_layer = PositionEmbeddingRandom(out_chans // 2)

        for name, module in self.backbone.named_modules():
            if name in self.target_layer_names:
                module.register_forward_hook(self.save_features_hook(name))

        self.face_decoder = FaceDecoder(
            transformer_dim=256,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ),
        )

        num_encoder_blocks = 4
        if self.model in ["swin_base", "swinv2_base", "convnext_base"]:
            hidden_sizes = [128, 256, 512, 1024]  ### Swin Base and ConvNext Base
        if self.model in ["resnet"]:
            hidden_sizes = [256, 512, 1024, 2048]  ### ResNet
        if self.model in [
            "swinv2_small",
            "swinv2_tiny",
            "convnext_small",
            "convnext_tiny",
        ]:
            hidden_sizes = [
                96,
                192,
                384,
                768,
            ]  ### Swin Small/Tiny and ConvNext Small/Tiny
        if self.model in ["mobilenet"]:
            hidden_sizes = [24, 40, 112, 960]  ### MobileNet
        if self.model in ["efficientnet"]:
            hidden_sizes = [48, 80, 176, 1280]  ### EfficientNet
        decoder_hidden_size = 256

        mlps = []
        for i in range(num_encoder_blocks):
            mlp = SegfaceMLP(input_dim=hidden_sizes[i])
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        # The following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_hidden_size * num_encoder_blocks,
            out_channels=decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )

    def save_features_hook(self, name):
        def hook(module, input, output):
            if self.model in [
                "swin_base",
                "swinv2_base",
                "swinv2_small",
                "swinv2_tiny",
            ]:
                self.multi_scale_features.append(
                    output.permute(0, 3, 1, 2).contiguous()
                )  ### Swin, Swinv2
            if self.model in [
                "convnext_base",
                "convnext_small",
                "convnext_tiny",
                "mobilenet",
                "efficientnet",
            ]:
                self.multi_scale_features.append(
                    output
                )  ### ConvNext, ResNet, EfficientNet, MobileNet

        return hook

    def forward(self, x):
        self.multi_scale_features.clear()

        _, _, h, w = x.shape
        features = self.backbone(x).squeeze()

        batch_size = self.multi_scale_features[-1].shape[0]
        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            # upsample
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=self.multi_scale_features[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        fused_states = self.linear_fuse(
            torch.cat(all_hidden_states[::-1], dim=1)
        )  #### torch.Size([BS, 256, 128, 128])
        image_pe = self.pe_layer(
            (fused_states.shape[2], fused_states.shape[3])
        ).unsqueeze(0)
        seg_output = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)

        return seg_output


# Wrapper class that initializes the models and configuration
class ImageGenerator:
    def __init__(self):
        self.args = self.get_args()
        self.pipeline, self.moe_model = self.get_model(self.args)
        with open(self.args.config_path, "r") as f:
            self.model_config = yaml.safe_load(f)["model"]
        self.farl = facer.face_parser(
            "farl/celebm/448",
            self.args.device,
            model_path="https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt",
        )
        self.segface = SegFaceCeleb(512, "convnext_base").to(self.args.device)
        checkpoint = torch.hub.load_state_dict_from_url("https://huggingface.co/kartiknarayan/SegFace/resolve/main/convnext_celeba_512/model_299.pt")
        self.segface.load_state_dict(checkpoint["state_dict_backbone"])
        self.segface.eval()
        self.segface_transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        self.seg_face_remap_dict = {
            0: 0, 1: 17, 2: 1, 3: 18, 4: 9, 5: 8, 6: 7, 7: 6,
            8: 5, 9: 4, 10: 2, 11: 10, 12: 12, 13: 11, 14: 13,
            15: 3, 16: 14, 17: 15, 18: 16,
        }

        self.palette = np.array(
            [
                (0, 0, 0), (204, 0, 0), (76, 153, 0), (204, 204, 0),
                (204, 0, 204), (51, 51, 255), (255, 204, 204), (0, 255, 255),
                (255, 0, 0), (102, 51, 0), (102, 204, 0), (255, 255, 0),
                (0, 0, 153), (0, 0, 204), (255, 51, 153), (0, 204, 204),
                (0, 51, 0), (255, 153, 51), (0, 204, 0),
            ],
            dtype=np.uint8,
        )

        self.org_labels = [
            "background", "face", "nose", "eyeg", "le", "re", "lb", "rb",
            "lr", "rr", "imouth", "ulip", "llip", "hair", "hat", "earr",
            "neck_l", "neck", "cloth",
        ]

        self.new_labels = [
            "background", "neck", "face", "cloth", "rr", "lr", "rb", "lb",
            "re", "le", "nose", "imouth", "llip", "ulip", "hair", "eyeg",
            "hat", "earr", "neck_l",
        ]

    @torch.no_grad()
    def parse_face_with_farl(self, image):
        image = image.resize((512, 512), Image.BICUBIC)
        image_np = np.array(image)
        image_pt = torch.tensor(image_np).to(self.args.device)
        image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).float()
        pred, _ = self.farl.net(image_pt / 255.0)
        vis_seg_probs = pred.argmax(dim=1).detach().cpu().numpy()[0].astype(np.uint8)
        remapped_mask = np.zeros_like(vis_seg_probs, dtype=np.uint8)
        for i, pred_label in enumerate(self.new_labels):
            if pred_label in self.org_labels:
                remapped_mask[vis_seg_probs == i] = self.org_labels.index(pred_label)
        vis_seg_probs = Image.fromarray(remapped_mask).convert("P")
        vis_seg_probs.putpalette(self.palette.flatten())
        return vis_seg_probs

    @torch.no_grad()
    def parse_face_with_segface(self, image):
        image = image.resize((512, 512), Image.BICUBIC)
        image = self.segface_transforms(image)
        logits = self.segface(image.unsqueeze(0).to(self.args.device))
        vis_seg_probs = logits.argmax(dim=1).detach().cpu().numpy()[0].astype(np.uint8)
        new_mask = np.zeros_like(vis_seg_probs)
        for old_idx, new_idx in self.seg_face_remap_dict.items():
            new_mask[vis_seg_probs == old_idx] = new_idx
        vis_seg_probs = Image.fromarray(new_mask).convert("P")
        vis_seg_probs.putpalette(self.palette.flatten())
        return vis_seg_probs

    def get_args(self):
        class Args:
            pipe = "black-forest-labs/FLUX.1-dev"
            lora_ckpt = "weights"
            moe_ckpt = "weights/mogle.pt"
            pretrained_ckpt = "weights/FLUX.1-dev"
            device = "cuda" if torch.cuda.is_available() else "cpu"
            size = 512
            seed = 42
            config_path = "config/Face-MoGLE.yaml"
        return Args()

    def get_model(self, args):
        pipeline = FluxPipeline.from_pretrained(
            args.pretrained_ckpt, torch_dtype=torch.bfloat16
        )
        pipeline.load_lora_weights(
            args.lora_ckpt, weight_name="pytorch_lora_weights.safetensors"
        )
        pipeline.to(args.device)
        moe_model = MoGLE()
        moe_weight = torch.load(args.moe_ckpt, map_location="cpu")
        moe_model.load_state_dict(moe_weight, strict=True)
        moe_model = moe_model.to(device=args.device, dtype=torch.bfloat16)
        moe_model.eval()
        return pipeline, moe_model

    def pack_data(self, mask_image: Image.Image):
        mask = np.array(mask_image.convert("L"))
        mask_list = [T.ToTensor()(mask_image.convert("RGB"))]
        for i in range(19):
            local_mask = np.zeros_like(mask)
            local_mask[mask == i] = 255
            local_mask_tensor = T.ToTensor()(Image.fromarray(local_mask).convert("RGB"))
            mask_list.append(local_mask_tensor)
        condition_img = torch.stack(mask_list, dim=0)
        return Condition(condition_type="depth", condition=condition_img, position_delta=[0, 0])

    def generate(self, prompt: str, mask_image: Image.Image, seed: int, num_inference_steps=28):
        generator = torch.Generator().manual_seed(seed)
        condition = self.pack_data(mask_image)
        result = generate(
            self.pipeline,
            mogle=self.moe_model,
            prompt=prompt,
            conditions=[condition],
            height=self.args.size,
            width=self.args.size,
            generator=generator,
            model_config=self.model_config,
            default_lora=True,
            num_inference_steps=num_inference_steps,
        )
        return result.images[0]


# Instantiate the generator
generator = ImageGenerator()

examples = [
    ["", "assets/mask2face/handou_seg.png", None, "FaRL", 42, 28],
    ["", "assets/mask2face/black_seg.png", None, "FaRL", 42, 28],
    ["She has red hair", "assets/multimodal/liuyifei_seg.png", None, "FaRL", 42, 28],
    ["He is old", "assets/multimodal/musk_seg.png", None, "FaRL", 42, 28],
    ["Curly-haired woman with glasses", None, None, "FaRL", 42, 28],
    ["Man with beard and tie", None, None, "FaRL", 42, 28],
]

# Gradio interface (using Blocks)
with gr.Blocks(title="Controllable Face Generation with MoGLE") as demo:
    gr.Markdown("## 🎭 Controllable Face Generation via Prompt + Face Parsing")

    with gr.Row():
        prompt = gr.Textbox(label="Text Prompt", placeholder="Describe the face you'd like to generate...")

    with gr.Row():
        with gr.Column():
            mask_image = gr.Image(type="pil", label="🧩 Semantic Mask (Optional)")
            rgb_image = gr.Image(type="pil", label="🖼️ Facial Image (Optional)")
            model_choice = gr.Radio(["FaRL", "SegFace"], label="Face Parsing Model", value="FaRL")
            seed = gr.Slider(minimum=0, maximum=100000, step=1, value=42, label="Random Seed")
            num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Sampling Steps")
            submit_btn = gr.Button("Generate")

        with gr.Column():
            gr.Markdown("### 🧠 Parsed Mask Preview")
            preview_mask = gr.Image(label="Parsed Mask (from RGB)", interactive=False)
            output_image = gr.Image(label="🎨 Generated Image")

    def generate_wrapper(prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps):
        if mask_image is None and rgb_image is not None:
            if model_choice == "FaRL":
                mask_image = generator.parse_face_with_farl(rgb_image)
            else:
                mask_image = generator.parse_face_with_segface(rgb_image)
        elif mask_image is None and rgb_image is None:
            # raise gr.Error("Please upload at least one input: a semantic mask or an RGB face image.")
            mask_image = Image.new("RGB", size=(512, 512))
        return mask_image, generator.generate(prompt, mask_image, seed, num_inference_steps)

    submit_btn.click(
        fn=generate_wrapper,
        inputs=[prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps],
        outputs=[preview_mask, output_image]
    )
    gr.Examples(
        examples=examples,
        inputs=[prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps],
        outputs=[preview_mask, output_image],
        fn=lambda *args: generate_wrapper(*args),  # directly reuse the function defined above
        cache_examples=False,
        label="Click any example below to try:"
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=5000, share=False)
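A note on the conditioning format: `ImageGenerator.pack_data` above stacks the full mask together with 19 per-class binary masks into a 20-entry condition tensor before handing it to `generate`. The minimal sketch below reproduces just that packing step on a random label map so the resulting layout is easy to inspect; the dummy labels and sizes are illustrative only.

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T

# dummy 512x512 label map with class ids 0..18 (same 19-class convention as app.py)
labels = np.random.randint(0, 19, size=(512, 512), dtype=np.uint8)
mask_image = Image.fromarray(labels)

mask = np.array(mask_image.convert("L"))
mask_list = [T.ToTensor()(mask_image.convert("RGB"))]  # global mask, 3 x 512 x 512
for i in range(19):                                     # one binary map per class
    local_mask = np.zeros_like(mask)
    local_mask[mask == i] = 255
    mask_list.append(T.ToTensor()(Image.fromarray(local_mask).convert("RGB")))

condition_img = torch.stack(mask_list, dim=0)
print(condition_img.shape)  # torch.Size([20, 3, 512, 512])
```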
assets/mask2face/black_seg.png
ADDED
assets/mask2face/handou_seg.png
ADDED
assets/multimodal/liuyifei_seg.png
ADDED
assets/multimodal/musk_seg.png
ADDED
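The example rows in `app.py` point at these masks. For reference, a minimal sketch of driving the same pipeline without the Gradio UI, assuming the repository layout above (FLUX.1-dev under `weights/FLUX.1-dev`, the LoRA and MoGLE weights under `weights/`) and a CUDA device with enough memory:

```python
from PIL import Image
import app  # importing app.py builds the pipeline and MoGLE head once at module level

mask = Image.open("assets/multimodal/liuyifei_seg.png")
image = app.generator.generate(
    prompt="She has red hair",
    mask_image=mask,
    seed=42,
    num_inference_steps=28,
)
image.save("result.png")
```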
config/Face-MoGLE.yaml
ADDED
@@ -0,0 +1,40 @@
# flux_path: "black-forest-labs/FLUX.1-dev"
sd_path: "checkpoints/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: false

train:
  batch_size: 4
  accumulate_grad_batches: 1
  dataloader_workers: 4
  save_interval: 1000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: true
  save_path: "runs/face-mogle"

  condition_type: "depth"
  dataset:
    root: "data/mmcelebahq"
    condition_size: 512
    target_size: 512
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
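`target_modules` above is a single regular expression; when PEFT receives a string here it matches it against each module name with a full-match regex, so an easy way to sanity-check which FLUX sub-modules the pattern selects is to test candidate names directly. A small sketch using an abridged version of the pattern; the module names below are illustrative, not read from the real model:

```python
import re

# two of the alternatives from the config regex, abridged for readability
pattern = (
    r"(.*x_embedder"
    r"|.*(?<!single_)transformer_blocks\.[0-9]+\.attn\.to_q"
    r"|.*single_transformer_blocks\.[0-9]+\.attn.to_q)"
)

for name in [
    "transformer.x_embedder",
    "transformer.transformer_blocks.3.attn.to_q",
    "transformer.single_transformer_blocks.7.attn.to_q",
    "transformer.transformer_blocks.3.ff.net.0.proj",
]:
    print(name, "->", bool(re.fullmatch(pattern, name)))
# True, True, True, False: double and single blocks are matched by separate
# alternatives, and untargeted layers fall through.
```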
requirements.txt
ADDED
@@ -0,0 +1,17 @@
diffusers==0.31.0
transformers
peft
opencv-python
protobuf
sentencepiece
gradio
jupyter
torchao
pyfacer
pyyaml

lightning
datasets
torchvision
prodigyopt
wandb
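Note that several of these packages install under a different name than they import (`pyfacer` → `facer`, `pyyaml` → `yaml`, `opencv-python` → `cv2`). A quick sanity check, using import names, that the runtime dependencies resolve before the Space starts:

```python
import importlib.util

# import names (not pip package names) for a few of the requirements above
modules = ["torch", "torchvision", "diffusers", "transformers", "peft",
           "gradio", "yaml", "facer", "cv2", "sentencepiece"]

missing = [m for m in modules if importlib.util.find_spec(m) is None]
print("missing modules:", missing or "none")
```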
src/flux/__pycache__/block.cpython-311.pyc
ADDED
Binary file (14.6 kB)
src/flux/__pycache__/condition.cpython-311.pyc
ADDED
Binary file (5.74 kB)
src/flux/__pycache__/generate.cpython-311.pyc
ADDED
Binary file (11.9 kB)
src/flux/__pycache__/lora_controller.cpython-311.pyc
ADDED
Binary file (5.12 kB)
src/flux/__pycache__/pipeline_tools.cpython-311.pyc
ADDED
Binary file (2.56 kB)
src/flux/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (7.85 kB)
src/flux/block.py
ADDED
@@ -0,0 +1,345 @@
import torch
from typing import List, Union, Optional, Dict, Any, Callable
from diffusers.models.attention_processor import Attention, F
from .lora_controller import enable_lora


def attn_forward(
    attn: Attention,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    condition_latents: torch.FloatTensor = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
    cond_rotary_emb: Optional[torch.Tensor] = None,
    model_config: Optional[Dict[str, Any]] = {},
) -> torch.FloatTensor:
    batch_size, _, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    with enable_lora(
        (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
    ):
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # print(query.shape, key.shape, value.shape)  # torch.Size([2, 1024, 3072]) x3

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
    if encoder_hidden_states is not None:
        # print(hidden_states.shape, encoder_hidden_states.shape, condition_latents.shape)
        # torch.Size([2, 1024, 3072]) torch.Size([2, 512, 3072]) torch.Size([2, 1024, 3072])
        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(
                encoder_hidden_states_query_proj
            )
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(
                encoder_hidden_states_key_proj
            )
        # attention
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)

    if condition_latents is not None:
        cond_query = attn.to_q(condition_latents)
        cond_key = attn.to_k(condition_latents)
        cond_value = attn.to_v(condition_latents)

        cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )
        cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )
        if attn.norm_q is not None:
            cond_query = attn.norm_q(cond_query)
        if attn.norm_k is not None:
            cond_key = attn.norm_k(cond_key)

        if cond_rotary_emb is not None:
            cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
            cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)

    if condition_latents is not None:
        query = torch.cat([query, cond_query], dim=2)
        key = torch.cat([key, cond_key], dim=2)
        value = torch.cat([value, cond_value], dim=2)

    if not model_config.get("union_cond_attn", True):
        # If we don't want to use the union condition attention, we need to mask the attention
        # between the hidden states and the condition latents
        attention_mask = torch.ones(
            query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
        )
        condition_n = cond_query.shape[2]
        attention_mask[-condition_n:, :-condition_n] = False
        attention_mask[:-condition_n, -condition_n:] = False
    elif model_config.get("independent_condition", False):
        attention_mask = torch.ones(
            query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
        )
        condition_n = cond_query.shape[2]
        attention_mask[-condition_n:, :-condition_n] = False
    if hasattr(attn, "c_factor"):
        attention_mask = torch.zeros(
            query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
        )
        condition_n = cond_query.shape[2]
        bias = torch.log(attn.c_factor[0])
        attention_mask[-condition_n:, :-condition_n] = bias
        attention_mask[:-condition_n, -condition_n:] = bias
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
    )
    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    # print(f"hidden_states {hidden_states.shape}")
    hidden_states = hidden_states.to(query.dtype)

    if encoder_hidden_states is not None:
        if condition_latents is not None:
            encoder_hidden_states, hidden_states, condition_latents = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[
                    :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
                ],
                hidden_states[:, -condition_latents.shape[1] :],
            )
        else:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

        with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if condition_latents is not None:
            condition_latents = attn.to_out[0](condition_latents)
            condition_latents = attn.to_out[1](condition_latents)

        return (
            (hidden_states, encoder_hidden_states, condition_latents)
            if condition_latents is not None
            else (hidden_states, encoder_hidden_states)
        )
    elif condition_latents is not None:
        # if there are condition_latents, we need to separate the hidden_states and the condition_latents
        hidden_states, condition_latents = (
            hidden_states[:, : -condition_latents.shape[1]],
            hidden_states[:, -condition_latents.shape[1] :],
        )
        # print(hidden_states.shape, condition_latents.shape)  # torch.Size([2, 1536, 3072]) torch.Size([2, 1024, 3072])
        return hidden_states, condition_latents
    else:
        return hidden_states


def block_forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    condition_latents: torch.FloatTensor,
    temb: torch.FloatTensor,
    cond_temb: torch.FloatTensor,
    cond_rotary_emb=None,
    image_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = {},
):
    use_cond = condition_latents is not None
    with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
        self.norm1_context(encoder_hidden_states, emb=temb)
    )
    # print(norm_encoder_hidden_states.shape, c_gate_msa.shape, c_shift_mlp.shape, c_scale_mlp.shape, c_gate_mlp.shape)
    # torch.Size([2, 512, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072])

    if use_cond:
        (
            norm_condition_latents,
            cond_gate_msa,
            cond_shift_mlp,
            cond_scale_mlp,
            cond_gate_mlp,
        ) = self.norm1(condition_latents, emb=cond_temb)

    # Attention.
    result = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        encoder_hidden_states=norm_encoder_hidden_states,
        condition_latents=norm_condition_latents if use_cond else None,
        image_rotary_emb=image_rotary_emb,
        cond_rotary_emb=cond_rotary_emb if use_cond else None,
    )
    attn_output, context_attn_output = result[:2]
    cond_attn_output = result[2] if use_cond else None

    # Process attention outputs for the `hidden_states`.
    # 1. hidden_states
    attn_output = gate_msa.unsqueeze(1) * attn_output
    # print(hidden_states.shape, attn_output.shape)  # torch.Size([2, 1024, 3072]) torch.Size([2, 1024, 3072])
    hidden_states = hidden_states + attn_output
    # 2. encoder_hidden_states
    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
    encoder_hidden_states = encoder_hidden_states + context_attn_output
    # 3. condition_latents
    if use_cond:
        cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
        condition_latents = condition_latents + cond_attn_output
        if model_config.get("add_cond_attn", False):
            hidden_states += cond_attn_output

    # LayerNorm + MLP.
    # 1. hidden_states
    norm_hidden_states = self.norm2(hidden_states)
    norm_hidden_states = (
        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    )
    # 2. encoder_hidden_states
    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
    )
    # 3. condition_latents
    if use_cond:
        norm_condition_latents = self.norm2(condition_latents)
        norm_condition_latents = (
            norm_condition_latents * (1 + cond_scale_mlp[:, None])
            + cond_shift_mlp[:, None]
        )

    # Feed-forward.
    with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
        # 1. hidden_states
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    # 2. encoder_hidden_states
    context_ff_output = self.ff_context(norm_encoder_hidden_states)
    context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
    # 3. condition_latents
    if use_cond:
        cond_ff_output = self.ff(norm_condition_latents)
        cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output

    # Process feed-forward outputs.
    hidden_states = hidden_states + ff_output
    encoder_hidden_states = encoder_hidden_states + context_ff_output
    if use_cond:
        condition_latents = condition_latents + cond_ff_output

    # Clip to avoid overflow.
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

    return encoder_hidden_states, hidden_states, condition_latents if use_cond else None


def single_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    condition_latents: torch.FloatTensor = None,
    cond_temb: torch.FloatTensor = None,
    cond_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = {},
):

    using_cond = condition_latents is not None
    residual = hidden_states
    with enable_lora(
        (
            self.norm.linear,
            self.proj_mlp,
        ),
        model_config.get("latent_lora", False),
    ):
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        if using_cond:
            residual_cond = condition_latents
            norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
            mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))

    attn_output = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        image_rotary_emb=image_rotary_emb,
        **(
            {
                "condition_latents": norm_condition_latents,
                "cond_rotary_emb": cond_rotary_emb if using_cond else None,
            }
            if using_cond
            else {}
        ),
    )
    if using_cond:
        attn_output, cond_attn_output = attn_output

    with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if using_cond:
            condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
            cond_gate = cond_gate.unsqueeze(1)
            condition_latents = cond_gate * self.proj_out(condition_latents)
            condition_latents = residual_cond + condition_latents

    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)

    return hidden_states if not using_cond else (hidden_states, condition_latents)
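The splitting logic at the end of attn_forward is easier to follow with the token layout in mind: prompt tokens, image tokens, and condition tokens are attended to as one long sequence and sliced apart again afterwards. Below is a minimal, self-contained sketch, not part of the repository; the sizes are taken from the debug comments above and are purely illustrative.

# Sketch of the [text | image | condition] sequence layout that attn_forward
# produces after attention (batch, seq, channels); the slicing mirrors the
# post-attention splitting in the code above.
import torch

text_tokens = torch.randn(2, 512, 3072)
image_tokens = torch.randn(2, 1024, 3072)
cond_tokens = torch.randn(2, 1024, 3072)

# Joint attention sees the concatenated sequence ...
joint = torch.cat([text_tokens, image_tokens, cond_tokens], dim=1)  # (2, 2560, 3072)

# ... and the three streams are recovered by slicing on the sequence dimension.
text_out = joint[:, : text_tokens.shape[1]]
image_out = joint[:, text_tokens.shape[1] : -cond_tokens.shape[1]]
cond_out = joint[:, -cond_tokens.shape[1] :]
assert text_out.shape == text_tokens.shape
assert image_out.shape == image_tokens.shape
assert cond_out.shape == cond_tokens.shape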
src/flux/condition.py
ADDED
@@ -0,0 +1,129 @@
import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2

from .pipeline_tools import encode_images

condition_dict = {
    "depth": 0,
    "canny": 1,
    "subject": 4,
    "coloring": 6,
    "deblurring": 7,
    "depth_pred": 8,
    "fill": 9,
    "sr": 10,
    "cartoon": 11,
}


class Condition(object):
    def __init__(
        self,
        condition_type: str,
        raw_img: Union[Image.Image, torch.Tensor] = None,
        condition: Union[Image.Image, torch.Tensor] = None,
        mask=None,
        position_delta=None,
        position_scale=1.0,
    ) -> None:
        self.condition_type = condition_type
        assert raw_img is not None or condition is not None
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition
        self.position_delta = position_delta
        self.position_scale = position_scale
        # TODO: Add mask support
        assert mask is None, "Mask not supported yet"

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """
        Returns the condition image.
        """
        if condition_type == "depth":
            from transformers import pipeline

            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cuda",
            )
            source_image = raw_img.convert("RGB")
            condition_img = depth_pipe(source_image)["depth"].convert("RGB")
            return condition_img
        elif condition_type == "canny":
            img = np.array(raw_img)
            edges = cv2.Canny(img, 100, 200)
            edges = Image.fromarray(edges).convert("RGB")
            return edges
        elif condition_type == "subject":
            return raw_img
        elif condition_type == "coloring":
            return raw_img.convert("L").convert("RGB")
        elif condition_type == "deblurring":
            condition_image = (
                raw_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(10))
                .convert("RGB")
            )
            return condition_image
        elif condition_type == "fill":
            return raw_img.convert("RGB")
        elif condition_type == "cartoon":
            return raw_img.convert("RGB")
        return self.condition

    @property
    def type_id(self) -> int:
        """
        Returns the type id of the condition.
        """
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id of the condition.
        """
        return condition_dict[condition_type]

    def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Encodes the condition into tokens, ids and type_id.
        """
        if self.condition_type in [
            "depth",
            "canny",
            "subject",
            "coloring",
            "deblurring",
            "depth_pred",
            "fill",
            "sr",
            "cartoon",
        ]:
            tokens, ids = encode_images(pipe, self.condition)
        else:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        if self.position_delta is None and self.condition_type == "subject":
            self.position_delta = [0, -self.condition.size[0] // 16]
        if self.position_delta is not None:
            ids[:, 1] += self.position_delta[0]
            ids[:, 2] += self.position_delta[1]
        if self.position_scale != 1.0:
            scale_bias = (self.position_scale - 1.0) / 2
            ids[:, 1] *= self.position_scale
            ids[:, 2] *= self.position_scale
            ids[:, 1] += scale_bias
            ids[:, 2] += scale_bias
        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id
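A hedged usage sketch for the Condition class above (not from the repository): build a condition from a raw PIL image, then later pass it to Condition.encode with a loaded FluxPipeline to obtain tokens and position ids. The blank 512x512 image is only a stand-in for a real photo.

# Sketch: constructing a "canny" Condition; cv2.Canny runs inside get_condition
# and the resulting edge map is stored as an RGB PIL image in `condition.condition`.
from PIL import Image
import numpy as np
from src.flux.condition import Condition

raw = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))  # stand-in image
cond = Condition(condition_type="canny", raw_img=raw, position_delta=[0, 0])
print(cond.type_id)         # 1, taken from condition_dict
print(cond.condition.size)  # (512, 512) edge map
# tokens, ids, type_id = cond.encode(pipe)  # requires a loaded FluxPipeline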
src/flux/generate.py
ADDED
@@ -0,0 +1,316 @@
import torch
import yaml, os
from diffusers.pipelines import FluxPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    rescale_noise_cfg,
)
from diffusers.utils import deprecate, is_torch_xla_available
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipelineOutput,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from torchvision import transforms as T
from typing import List, Union, Optional, Dict, Any, Callable
from .transformer import tranformer_forward
from .condition import Condition

from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
    np,
)


def get_config(config_path: str = None):
    config_path = config_path or os.environ.get("XFL_CONFIG")
    if not config_path:
        return {}
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def prepare_params(
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    **kwargs: dict,
):
    return (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
    )


def seed_everything(seed: int = 42):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    np.random.seed(seed)


@torch.no_grad()
def generate(
    pipeline: FluxPipeline,
    mogle=None,
    conditions: List[Condition] = None,
    config_path: str = None,
    model_config: Optional[Dict[str, Any]] = {},
    condition_scale: float = 1.0,
    default_lora: bool = False,
    **params: dict,
):
    model_config = model_config or get_config(config_path).get("model", {})
    if condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            module.c_factor = torch.ones(1, 1) * condition_scale

    self = pipeline
    (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
    ) = prepare_params(**params)

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None)
        if self.joint_attention_kwargs is not None
        else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 4.1. Prepare conditions
    condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
    use_condition = conditions is not None or []
    if use_condition:
        assert len(conditions) <= 1, "Only one condition is supported for now."
        if not default_lora:
            pipeline.set_adapters(conditions[0].condition_type)
        for condition in conditions:
            tokens, ids, type_id = condition.encode(self)
            # print(tokens.shape)  # 20 1024 64
            # bs, mask_num, channel, h, w = tokens.shape
            tokens_reshape = tokens.reshape(1, -1, *tokens.shape[1:])
            # print(tokens_reshape.shape)  # 1 20 1024 64
            condition_latents.append(tokens_reshape)  # [batch_size, token_n, token_dim]
            condition_ids.append(ids)  # [token_n, id_dim(3)]
            condition_type_ids.append(type_id)  # [token_n, 1]
        condition_latents = torch.cat(condition_latents, dim=1)
        condition_ids = torch.cat(condition_ids, dim=0)
        condition_type_ids = torch.cat(condition_type_ids, dim=0)

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * self.scheduler.order, 0
    )
    self._num_timesteps = len(timesteps)

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            cur_condition_latents = mogle(condition_latents, latents, t.expand(1))
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # handle guidance
            if self.transformer.config.guidance_embeds:
                guidance = torch.tensor([guidance_scale], device=device)
                guidance = guidance.expand(latents.shape[0])
            else:
                guidance = None
            noise_pred = tranformer_forward(
                self.transformer,
                model_config=model_config,
                # Inputs of the condition (new feature)
                condition_latents=cur_condition_latents if use_condition else None,
                condition_ids=condition_ids if use_condition else None,
                condition_type_ids=condition_type_ids if use_condition else None,
                # Inputs to the original transformer
                hidden_states=latents,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()

    if output_type == "latent":
        image = latents

    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (
            latents / self.vae.config.scaling_factor
        ) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            del module.c_factor

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)
src/flux/lora_controller.py
ADDED
@@ -0,0 +1,75 @@
from peft.tuners.tuners_utils import BaseTunerLayer
from typing import List, Any, Optional, Type


class enable_lora:
    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        if activated:
            return
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        if self.activated:
            return

        for lora_module in self.lora_modules:
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self.activated:
            return
        for i, lora_module in enumerate(self.lora_modules):
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scales[i][active_adapter]


class set_lora_scale:
    def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]
        self.scale = scale

    def __enter__(self) -> None:
        for lora_module in self.lora_modules:
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            lora_module.scale_layer(self.scale)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        for i, lora_module in enumerate(self.lora_modules):
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
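A small illustration of the two context managers (a sketch, assuming a recent peft release in which get_peft_model accepts a plain nn.Module and BaseTunerLayer exposes scaling and scale_layer): enable_lora with activated=False temporarily zeroes the LoRA contribution of the given layers and restores the saved scaling on exit, while set_lora_scale applies an arbitrary scale in the same way.

# Sketch: wrap a LoRA-injected Linear layer and toggle its LoRA scaling.
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer
from src.flux.lora_controller import enable_lora, set_lora_scale

base = nn.Sequential(nn.Linear(16, 16))
peft_model = get_peft_model(base, LoraConfig(r=4, lora_alpha=4, target_modules=["0"]))

lora_layers = [m for m in peft_model.modules() if isinstance(m, BaseTunerLayer)]
with enable_lora(tuple(lora_layers), activated=False):
    pass  # inside: LoRA scaling set to 0, so only the base weights act
with set_lora_scale(lora_layers, scale=0.5):
    pass  # inside: LoRA contribution halved; original scaling restored on exit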
src/flux/pipeline_tools.py
ADDED
@@ -0,0 +1,52 @@
from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.pipelines.flux.pipeline_flux import logger
from torch import Tensor


def encode_images(pipeline: FluxPipeline, images: Tensor):
    images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
        images - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor
    images_tokens = pipeline._pack_latents(images, *images.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        images.shape[0],
        images.shape[2],
        images.shape[3],
        pipeline.device,
        pipeline.dtype,
    )
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = pipeline._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2] // 2,
            images.shape[3] // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids


def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = pipeline.encode_prompt(
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return prompt_embeds, pooled_prompt_embeds, text_ids
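encode_images is where a condition image becomes the (tokens, ids) pair used elsewhere. A back-of-the-envelope check (a sketch, assuming the standard FLUX VAE configuration of 8x spatial downsampling, 16 latent channels, and 2x2 patch packing) shows why a 512x512 input turns into the (1024, 64) token grid seen in the debug comments.

# Sketch: token count and token dimension for a 512x512 condition image.
image_size = 512
vae_downsample = 8       # assumed FLUX VAE downsampling factor
latent_channels = 16     # assumed FLUX VAE latent channels
patch = 2                # _pack_latents groups 2x2 latent patches into one token

latent_hw = image_size // vae_downsample        # 64 x 64 latent grid
tokens = (latent_hw // patch) ** 2              # 32 * 32 = 1024 packed tokens
token_dim = latent_channels * patch * patch     # 16 * 4  = 64 channels per token
print(tokens, token_dim)                        # 1024 64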
src/flux/transformer.py
ADDED
@@ -0,0 +1,257 @@
import torch
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .block import block_forward, single_block_forward
from .lora_controller import enable_lora
from accelerate.utils import is_torch_version
from diffusers.models.transformers.transformer_flux import (
    FluxTransformer2DModel,
    Transformer2DModelOutput,
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
    logger,
)
import numpy as np


def prepare_params(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    **kwargs: dict,
):
    return (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    )


def tranformer_forward(
    transformer: FluxTransformer2DModel,
    condition_latents: torch.Tensor,
    condition_ids: torch.Tensor,
    condition_type_ids: torch.Tensor,
    model_config: Optional[Dict[str, Any]] = {},
    c_t=0,
    **params: dict,
):
    self = transformer
    use_condition = condition_latents is not None

    (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    ) = prepare_params(**params)

    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
        hidden_states = self.x_embedder(hidden_states)
        # print("hidden states :", hidden_states.shape)  # torch.Size([2, 1024, 3072])
    condition_latents = self.x_embedder(condition_latents) if use_condition else None
    # print(f"condition_latents shape {condition_latents.shape}")  # torch.Size([2, 1024, 3072])

    timestep = timestep.to(hidden_states.dtype) * 1000

    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
    else:
        guidance = None

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    # print(f"temb shape:{temb.shape}")  # torch.Size([2, 3072])

    cond_temb = (
        self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
        if guidance is None
        else self.time_text_embed(
            torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
        )
    )
    # print("cond temb shape", cond_temb.shape)  # torch.Size([2, 3072])
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
    # print(f"encoder hidden states {encoder_hidden_states.shape}")  # torch.Size([2, 512, 3072])

    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)  # 1536 3
    image_rotary_emb = self.pos_embed(ids)  # 2 1536 128

    if use_condition:
        # condition_ids[:, :1] = condition_type_ids
        cond_rotary_emb = self.pos_embed(condition_ids)  # 2 1536 128
    # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)

    for index_block, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            encoder_hidden_states, hidden_states, condition_latents = (
                torch.utils.checkpoint.checkpoint(
                    block_forward,
                    self=block,
                    model_config=model_config,
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    condition_latents=condition_latents if use_condition else None,
                    temb=temb,
                    cond_temb=cond_temb if use_condition else None,
                    cond_rotary_emb=cond_rotary_emb if use_condition else None,
                    image_rotary_emb=image_rotary_emb,
                    **ckpt_kwargs,
                )
            )

        else:
            encoder_hidden_states, hidden_states, condition_latents = block_forward(
                block,
                model_config=model_config,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                condition_latents=condition_latents if use_condition else None,
                temb=temb,
                cond_temb=cond_temb if use_condition else None,
                cond_rotary_emb=cond_rotary_emb if use_condition else None,
                image_rotary_emb=image_rotary_emb,
            )

        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(
                controlnet_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states = (
                hidden_states
                + controlnet_block_samples[index_block // interval_control]
            )
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        if self.training and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            result = torch.utils.checkpoint.checkpoint(
                single_block_forward,
                self=block,
                model_config=model_config,
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                **(
                    {
                        "condition_latents": condition_latents,
                        "cond_temb": cond_temb,
                        "cond_rotary_emb": cond_rotary_emb,
                    }
                    if use_condition
                    else {}
                ),
                **ckpt_kwargs,
            )

        else:
            result = single_block_forward(
                block,
                model_config=model_config,
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                **(
                    {
                        "condition_latents": condition_latents,
                        "cond_temb": cond_temb,
                        "cond_rotary_emb": cond_rotary_emb,
                    }
                    if use_condition
                    else {}
                ),
            )
        if use_condition:
            hidden_states, condition_latents = result
        else:
            hidden_states = result

        # controlnet residual
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(
                controlnet_single_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)
    # print(f"output shape:{output.shape}")
    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
src/moe/__pycache__/mogle.cpython-311.pyc
ADDED
Binary file (7.47 kB). View file
src/moe/mogle.py
ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
import torch.optim as optim
from torch.nn import functional as F


# Define the Expert Network
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, use_softmax=False):
        super(Expert, self).__init__()

        self.use_softmax = use_softmax

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return (
            self.net(x) if not self.use_softmax else torch.softmax(self.net(x), dim=1)
        )


class DynamicGatingNetwork(nn.Module):
    def __init__(self, hidden_dim=64, embed_dim=64, dtype=torch.bfloat16):
        super().__init__()

        # Embed the diffusion timestep
        self.time_proj = Timesteps(
            hidden_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.timestep_embedding = TimestepEmbedding(hidden_dim, embed_dim)
        self.timestep_embedding = self.timestep_embedding.to(dtype=torch.bfloat16)
        # Project the noise latent
        self.noise_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dtype = dtype

        # Compute the per-token expert weights
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 20),  # one weight per expert
        )

    def forward(self, condition_latents, noise_latent, timestep):
        """
        global_latents: (bs, 1024, 64)
        noise_latent: (bs, 1024, 64)
        timestep: (bs,)
        """
        bs, seq_len, hidden_dim = condition_latents.shape

        # Embed the timestep
        time_emb = self.time_proj(timestep)  # (bs, hidden_dim)
        time_emb = time_emb.to(self.dtype)
        time_emb = self.timestep_embedding(time_emb)  # (bs, embed_dim)

        time_emb = time_emb.unsqueeze(1).expand(
            -1, seq_len, -1
        )  # (bs, 1024, embed_dim)

        # Project the noise latent
        noise_emb = self.noise_proj(noise_latent)  # (bs, 1024, 64)
        # Fuse all inputs
        # fused_input = torch.cat([condition_latents, noise_emb, time_emb], dim=2)  # (bs, 1024, 64+64+128)
        fused_input = condition_latents + noise_emb + time_emb
        # Compute the weights
        weight = self.gate(fused_input)  # (bs, 1024, 20)
        weight = F.softmax(weight, dim=2)  # normalize over the experts

        return weight


class MoGLE(nn.Module):
    def __init__(
        self,
        num_experts=20,
        input_dim=64,
        hidden_dim=32,
        output_dim=64,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    ):
        super().__init__()
        expert_model = None
        if has_expert:
            expert_model = Expert
        else:
            expert_model = nn.Identity
        self.global_expert = expert_model(
            input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
        )
        self.local_experts = nn.ModuleList(
            [
                expert_model(
                    input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
                )
                for _ in range(num_experts - 1)
            ]
        )
        # self.gating = Gating(input_dim=input_dim, num_experts=num_experts)
        if has_gating:
            self.gating = DynamicGatingNetwork()
        else:
            self.gating = nn.Identity()

        self.weight_is_scale = weight_is_scale

    def forward(self, x: torch.Tensor, noise_latent, timestep):
        global_mask = x[:, 0]  # bs 1024 64
        local_mask = x[:, 1:]  # bs 19 1024 64
        if not isinstance(self.gating, nn.Identity):
            weights = self.gating.forward(
                global_mask, noise_latent=noise_latent, timestep=timestep
            )  # bs 1024 20

        _, num_local, h, w = local_mask.shape
        global_output = self.global_expert(global_mask).unsqueeze(1)
        local_outputs = torch.stack(
            [self.local_experts[i](local_mask[:, i]) for i in range(num_local)], dim=1
        )  # (bs, 19, 1024, 64)
        global_local_outputs = torch.cat(
            [global_output, local_outputs], dim=1
        )  # bs 20 1024 64

        if isinstance(self.gating, nn.Identity):
            global_local_outputs = global_local_outputs.sum(dim=1)
            return global_local_outputs
        if self.weight_is_scale:
            weights = torch.mean(weights, dim=1, keepdim=True)  # bs 1 20
            # print("gating scale")

        weights_expanded = weights.unsqueeze(-1)
        output = (global_local_outputs.permute(0, 2, 1, 3) * weights_expanded).sum(
            dim=2
        )
        return output  # bs 1024 64
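A quick shape check for MoGLE (a sketch, not from the repository): feed it random tensors shaped like the packed latents used elsewhere, namely one global mask map plus 19 per-region maps of 1024 tokens with dimension 64, together with a one-element timestep, and confirm the fused output matches the noisy-latent token shape.

# Sketch: exercise MoGLE with dummy inputs and verify the output shape.
import torch
from src.moe.mogle import MoGLE

mogle = MoGLE().to(torch.bfloat16)
cond = torch.randn(1, 20, 1024, 64, dtype=torch.bfloat16)   # [global | 19 local] expert inputs
noise = torch.randn(1, 1024, 64, dtype=torch.bfloat16)      # current noisy latent tokens
t = torch.tensor([981.0])                                    # scheduler timestep, as in generate()

out = mogle(cond, noise, t)
print(out.shape)  # torch.Size([1, 1024, 64])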
src/train/callbacks.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import lightning as L
|
2 |
+
from PIL import Image, ImageFilter, ImageDraw
|
3 |
+
import numpy as np
|
4 |
+
from transformers import pipeline
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
from torchvision import transforms as T
|
9 |
+
try:
|
10 |
+
import wandb
|
11 |
+
except ImportError:
|
12 |
+
wandb = None
|
13 |
+
|
14 |
+
from ..flux.condition import Condition
from ..flux.generate import generate


class FaceMoGLECallback(L.Callback):

    def __init__(self, run_name, training_config: dict = {}):
        self.run_name, self.training_config = run_name, training_config

        self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
        self.save_interval = training_config.get("save_interval", 1000)
        self.sample_interval = training_config.get("sample_interval", 1000)
        self.save_path = training_config.get("save_path", "./runs")

        self.wandb_config = training_config.get("wandb", None)
        self.use_wandb = (
            wandb is not None and os.environ.get("WANDB_API_KEY") is not None
        )

        self.total_steps = 0

    def to_tensor(self, x):
        return T.ToTensor()(x)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Track the mean and maximum L2 norm of the gradients of all trainable parameters
        gradient_size = 0
        max_gradient_size = 0
        count = 0
        for _, param in pl_module.named_parameters():
            if param.grad is not None:
                gradient_size += param.grad.norm(2).item()
                max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
                count += 1
        if count > 0:
            gradient_size /= count

        self.total_steps += 1

        # Log training progress to Weights & Biases when it is configured
        if self.use_wandb:
            report_dict = {
                "steps": self.total_steps,
                "epoch": trainer.current_epoch,
                "gradient_size": gradient_size,
            }
            loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
            report_dict["loss"] = loss_value
            report_dict["t"] = pl_module.last_t
            wandb.log(report_dict)

        # Print training progress every n steps
        if self.total_steps % self.print_every_n_steps == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
            )

        # Save LoRA (and MoGLE) weights at specified intervals
        if self.total_steps % self.save_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
            )
            pl_module.save_lora(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
            )
            if hasattr(pl_module, "save_moe"):
                pl_module.save_moe(
                    f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}/moe.pt"
                )

        # Generate and save a sample image at specified intervals
        if self.total_steps % self.sample_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            self.generate_a_sample(
                trainer,
                pl_module,
                f"{self.save_path}/{self.run_name}/output",
                f"lora_{self.total_steps}",
                batch["condition_type"][0],  # Use the condition type from the current batch
            )

    @torch.no_grad()
    def generate_a_sample(
        self,
        trainer,
        pl_module,
        save_path,
        file_name,
        condition_type="super_resolution",
    ):
        # TODO: turn these two variables into parameters
        target_size = trainer.training_config["dataset"]["target_size"]
        position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)

        generator = torch.Generator(device=pl_module.device)
        generator.manual_seed(42)

        test_list = []

        condition_img_path = "data/mmcelebahq/mask/27000.png"
        test_list.append(
            (
                condition_img_path,
                [0, 0],
                "She is wearing lipstick. She is attractive and has straight hair.",
                {"position_scale": position_scale} if position_scale != 1.0 else {},
            )
        )

        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for i, (condition_img_path, position_delta, prompt, *others) in enumerate(
            test_list
        ):
            # Build the condition tensor: the full segmentation map plus 19 binary class masks
            global_mask = Image.open(condition_img_path).convert("RGB")
            mask_list = [self.to_tensor(global_mask)]
            mask = Image.open(condition_img_path)
            mask = np.array(mask)
            for cls_id in range(19):
                local_mask = np.zeros_like(mask)
                local_mask[mask == cls_id] = 255

                local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
                local_mask_tensor = self.to_tensor(local_mask_rgb)
                mask_list.append(local_mask_tensor)
            condition_img = torch.stack(mask_list, dim=0)  # [20, 3, H, W]

            condition = Condition(
                condition_type=condition_type,
                condition=condition_img,
                position_delta=position_delta,
                **(others[0] if others else {}),
            )

            res = generate(
                pl_module.flux_pipe,
                mogle=pl_module.mogle,
                prompt=prompt,
                conditions=[condition],
                height=target_size,
                width=target_size,
                generator=generator,
                model_config=pl_module.model_config,
                default_lora=True,
            )
            res.images[0].save(
                os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
            )
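For orientation, a minimal sketch of how this callback could be attached to a Lightning Trainer follows; the config values and trainer arguments are illustrative assumptions, not the repository's actual training entry point. Note that generate_a_sample reads trainer.training_config, so that attribute has to be set on the trainer explicitly.

# Illustrative sketch only -- the real training script is not part of this upload.
import lightning as L

from src.train.callbacks import FaceMoGLECallback

training_config = {
    "print_every_n_steps": 10,        # console log frequency used by the callback
    "save_interval": 1000,            # save LoRA + MoGLE weights every N steps
    "sample_interval": 1000,          # render a mask-to-face sample every N steps
    "save_path": "./runs",
    "dataset": {"target_size": 512},  # read by generate_a_sample
}

callback = FaceMoGLECallback(run_name="face-mogle-demo", training_config=training_config)
trainer = L.Trainer(max_steps=20000, callbacks=[callback])
trainer.training_config = training_config  # generate_a_sample expects this attribute
# trainer.fit(model, train_dataloaders=train_loader)  # model and loader defined elsewhere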
src/train/data.py
ADDED
@@ -0,0 +1,98 @@
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import random
import torch
import json


class MMCelebAHQ(Dataset):
    def __init__(
        self,
        root="data/mmcelebahq",
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.root = root
        self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def get_face_mask_prompt(self):
        # The first 27,000 face/mask pairs are used for training
        face_paths = [
            os.path.join(self.root, "face", f"{i}.jpg") for i in range(0, 27000)
        ]
        mask_paths = [
            os.path.join(self.root, "mask", f"{i}.png") for i in range(0, 27000)
        ]
        with open(os.path.join(self.root, "text.json"), mode="r") as f:
            prompts = json.load(f)
        return face_paths, mask_paths, prompts

    def __len__(self):
        return len(self.face_paths)

    def __getitem__(self, idx):
        image = Image.open(self.face_paths[idx]).convert("RGB")
        prompts = self.prompts[f"{idx}.jpg"]
        description = random.choices(prompts, k=1)[0].strip()
        # Note: this condition is always True, so the configured position_scale is always used
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Get the condition image
        position_delta = np.array([0, 0])

        # Build the condition: the full segmentation map plus 19 binary class masks
        mask = np.array(Image.open(self.mask_paths[idx]))
        mask_list = [self.to_tensor(Image.open(self.mask_paths[idx]).convert("RGB"))]
        for i in range(19):
            local_mask = np.zeros_like(mask)
            local_mask[mask == i] = 255

            # Randomly drop (zero out) each local mask with drop_image_prob
            drop_image = random.random() < self.drop_image_prob
            if drop_image:
                local_mask = np.zeros_like(mask)

            local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
            local_mask_tensor = self.to_tensor(local_mask_rgb)
            mask_list.append(local_mask_tensor)
        condition_img = torch.stack(mask_list, dim=0)  # [20, 3, H, W]

        # Randomly drop the text prompt
        drop_text = random.random() < self.drop_text_prob
        if drop_text:
            description = ""

        return {
            "image": self.to_tensor(image),
            "condition": condition_img,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
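A quick sanity-check sketch for the dataset above; it assumes the expected data/mmcelebahq layout (face/*.jpg, mask/*.png, text.json) is present locally and is otherwise illustrative.

# Illustrative sketch: requires data/mmcelebahq/{face,mask,text.json} on disk.
from torch.utils.data import DataLoader

from src.train.data import MMCelebAHQ

dataset = MMCelebAHQ(root="data/mmcelebahq")
sample = dataset[0]
print(sample["image"].shape)      # [3, H, W] face photo at its stored resolution
print(sample["condition"].shape)  # [20, 3, 512, 512]: full mask + 19 per-class masks
print(sample["description"])      # one randomly chosen caption ("" if text was dropped)

loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["condition"].shape)   # [2, 20, 3, 512, 512]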
src/train/model.py
ADDED
@@ -0,0 +1,201 @@
import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict

import prodigyopt
import os
from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input

from ..moe.mogle import MoGLE


class FaceMoGLE(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = {},
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()
        self.mogle = MoGLE(
            has_expert=has_expert, has_gating=has_gating, weight_is_scale=weight_is_scale
        )
        self.mogle.train()

        # Freeze the Flux pipeline
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad)
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
        return list(lora_layers)

    def save_lora(self, path: str):
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )
        torch.save(self.mogle.state_dict(), os.path.join(path, "mogle.pt"))

    def configure_optimizers(self):
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters (LoRA layers + MoGLE)
        self.trainable_params = self.lora_layers + [p for p in self.mogle.parameters()]

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError

        return optimizer

    def training_step(self, batch, batch_idx):
        step_loss = self.step(batch)
        # Exponential moving average of the loss for logging
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def step(self, batch):
        imgs = batch["image"]
        conditions = batch["condition"]  # [bs, 20, 3, 512, 512]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Prepare conditions: reshape [bs, 20, 3, 512, 512] -> [bs*20, 3, 512, 512]
            # so the global mask and the 19 local masks are encoded by the VAE in one pass
            c_bs, c_classes, c_channels, c_h, c_w = conditions.shape
            conditions = conditions.view(c_bs * c_classes, c_channels, c_h, c_w)

            # condition_latents: [bs*20, 1024, 64] mask features
            # condition_ids: [1024, 3] position embedding
            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
            condition_latents_reshape = condition_latents.reshape(
                c_bs, c_classes, *condition_latents.shape[-2:]
            )  # [bs, 20, 1024, 64]

        # Fuse the global and local mask latents with MoGLE (kept outside no_grad so it trains)
        condition_latents = self.mogle.forward(
            condition_latents_reshape, noise_latent=x_t, timestep=t
        )

        # Add position delta
        condition_ids[:, 1] += position_delta[0]
        condition_ids[:, 2] += position_delta[1]

        if position_scale != 1.0:
            scale_bias = (position_scale - 1.0) / 2
            condition_ids[:, 1] *= position_scale
            condition_ids[:, 2] *= position_scale
            condition_ids[:, 1] += scale_bias
            condition_ids[:, 2] += scale_bias

        # Prepare condition type
        condition_type_ids = torch.tensor(
            [
                Condition.get_type_id(condition_type)
                for condition_type in condition_types
            ]
        ).to(self.device)
        condition_type_ids = (
            torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
        ).unsqueeze(1)

        # Prepare guidance
        guidance = (
            torch.ones_like(t).to(self.device)
            if self.transformer.config.guidance_embeds
            else None
        )

        # Forward pass
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Flow-matching loss: regress the predicted velocity onto (x_1 - x_0)
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()
        return loss
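step() above is a rectified-flow objective: with clean image latents x_0, Gaussian noise x_1, and t ~ sigmoid(N(0, 1)), it forms x_t = (1 - t)·x_0 + t·x_1 and regresses the transformer output onto the velocity x_1 - x_0, with the MoGLE-fused mask latents injected as extra condition tokens. A hedged sketch of constructing and training the module follows; the pipeline id, LoRA hyperparameters, and optimizer settings are assumptions (the repository's YAML config holds the real values).

# All hyperparameters below are assumed for illustration.
import lightning as L
from torch.utils.data import DataLoader

from src.train.data import MMCelebAHQ
from src.train.model import FaceMoGLE

model = FaceMoGLE(
    flux_pipe_id="black-forest-labs/FLUX.1-dev",   # assumed FLUX base pipeline
    lora_config={                                  # assumed LoRA hyperparameters
        "r": 4,
        "lora_alpha": 4,
        "init_lora_weights": "gaussian",
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    },
    optimizer_config={"type": "Prodigy", "params": {"lr": 1.0}},
    gradient_checkpointing=True,
)

loader = DataLoader(MMCelebAHQ(root="data/mmcelebahq"), batch_size=1, shuffle=True)
trainer = L.Trainer(max_steps=20000)  # add FaceMoGLECallback as sketched earlier
trainer.fit(model, train_dataloaders=loader)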
weights/mogle.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b071a349d1e8f922d32a066014f9cc80b39f8db55043d8bdf04e79e156d4f243
size 238252
weights/pytorch_lora_weights.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b2202f249a33252ce4f630db2f9536a28caf4b90e27927633f1f3bbb121f774
size 29066872
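The two weight entries above are Git LFS pointer files rather than the binaries themselves: mogle.pt holds the trained MoGLE state dict and pytorch_lora_weights.safetensors the FLUX LoRA. A rough loading sketch for inference is given below (the pipeline id and dtypes are assumptions; the Space's demo app performs the real wiring), after which sampling goes through src.flux.generate.generate with a Condition built from the 20-channel mask stack, as in FaceMoGLECallback.generate_a_sample.

# Sketch under assumptions: run `git lfs pull` first so the real binaries are present.
import torch
from diffusers.pipelines import FluxPipeline

from src.moe.mogle import MoGLE

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # assumed base model
).to("cuda")
pipe.load_lora_weights("weights", weight_name="pytorch_lora_weights.safetensors")

mogle = MoGLE(has_expert=True, has_gating=True, weight_is_scale=False)  # defaults from FaceMoGLE
mogle.load_state_dict(torch.load("weights/mogle.pt", map_location="cpu"))
mogle = mogle.to("cuda", dtype=torch.bfloat16).eval()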