XavierJiezou committed on
Commit 48ed5ae · verified · 1 Parent(s): 3f02207

Upload folder using huggingface_hub

LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2024] [Zhenxiong Tan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,821 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import yaml
5
+ import numpy as np
6
+ from torchvision.models import convnext_base, convnext_small
7
+ from torch import nn as nn
8
+ import facer
9
+ from torch import Tensor
10
+ import math
11
+ from typing import Any, Optional, Tuple, Type
12
+ from torch.nn import functional as F
13
+ import torchvision
14
+ from torchvision import transforms as T
15
+ from src.flux.generate import generate
16
+ from diffusers.pipelines import FluxPipeline
17
+ from src.flux.condition import Condition
18
+ from src.moe.mogle import MoGLE
19
+
20
+
21
+ class LayerNorm2d(nn.Module):
22
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
23
+ super().__init__()
24
+ self.weight = nn.Parameter(torch.ones(num_channels))
25
+ self.bias = nn.Parameter(torch.zeros(num_channels))
26
+ self.eps = eps
27
+
28
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
29
+ u = x.mean(1, keepdim=True)
30
+ s = (x - u).pow(2).mean(1, keepdim=True)
31
+ x = (x - u) / torch.sqrt(s + self.eps)
32
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
33
+ return x
34
+
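A quick sanity check, not part of the upload: LayerNorm2d normalizes over the channel dimension only, so on an NCHW tensor it should agree with torch.nn.LayerNorm applied to a channels-last view once eps is matched to 1e-6:

    ln2d = LayerNorm2d(64)
    ln = nn.LayerNorm(64, eps=1e-6)
    x = torch.randn(2, 64, 32, 32)
    ref = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # normalize channels-last, restore NCHW
    assert torch.allclose(ln2d(x), ref, atol=1e-6)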
35
+
36
+ class MLP(nn.Module):
37
+ def __init__(
38
+ self,
39
+ input_dim: int,
40
+ hidden_dim: int,
41
+ output_dim: int,
42
+ num_layers: int,
43
+ sigmoid_output: bool = False,
44
+ ) -> None:
45
+ super().__init__()
46
+ self.num_layers = num_layers
47
+ h = [hidden_dim] * (num_layers - 1)
48
+ self.layers = nn.ModuleList(
49
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
50
+ )
51
+ self.sigmoid_output = sigmoid_output
52
+
53
+ def forward(self, x):
54
+ for i, layer in enumerate(self.layers):
55
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
56
+ if self.sigmoid_output:
57
+ x = F.sigmoid(x)
58
+ return x
59
+
60
+
61
+ class FaceDecoder(nn.Module):
62
+ def __init__(
63
+ self,
64
+ *,
65
+ transformer_dim: int = 256,
66
+ transformer: nn.Module,
67
+ activation: Type[nn.Module] = nn.GELU,
68
+ ) -> None:
69
+
70
+ super().__init__()
71
+ self.transformer_dim = transformer_dim
72
+ self.transformer = transformer
73
+
74
+ self.background_token = nn.Embedding(1, transformer_dim)
75
+ self.neck_token = nn.Embedding(1, transformer_dim)
76
+ self.face_token = nn.Embedding(1, transformer_dim)
77
+ self.cloth_token = nn.Embedding(1, transformer_dim)
78
+ self.rightear_token = nn.Embedding(1, transformer_dim)
79
+ self.leftear_token = nn.Embedding(1, transformer_dim)
80
+ self.rightbro_token = nn.Embedding(1, transformer_dim)
81
+ self.leftbro_token = nn.Embedding(1, transformer_dim)
82
+ self.righteye_token = nn.Embedding(1, transformer_dim)
83
+ self.lefteye_token = nn.Embedding(1, transformer_dim)
84
+ self.nose_token = nn.Embedding(1, transformer_dim)
85
+ self.innermouth_token = nn.Embedding(1, transformer_dim)
86
+ self.lowerlip_token = nn.Embedding(1, transformer_dim)
87
+ self.upperlip_token = nn.Embedding(1, transformer_dim)
88
+ self.hair_token = nn.Embedding(1, transformer_dim)
89
+ self.glass_token = nn.Embedding(1, transformer_dim)
90
+ self.hat_token = nn.Embedding(1, transformer_dim)
91
+ self.earring_token = nn.Embedding(1, transformer_dim)
92
+ self.necklace_token = nn.Embedding(1, transformer_dim)
93
+
94
+ self.output_upscaling = nn.Sequential(
95
+ nn.ConvTranspose2d(
96
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
97
+ ),
98
+ LayerNorm2d(transformer_dim // 4),
99
+ activation(),
100
+ nn.ConvTranspose2d(
101
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
102
+ ),
103
+ activation(),
104
+ )
105
+
106
+ self.output_hypernetwork_mlps = MLP(
107
+ transformer_dim, transformer_dim, transformer_dim // 8, 3
108
+ )
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
115
+ """
116
+ image_embeddings - torch.Size([1, 256, 128, 128])
117
+ image_pe - torch.Size([1, 256, 128, 128])
118
+ """
119
+ output_tokens = torch.cat(
120
+ [
121
+ self.background_token.weight,
122
+ self.neck_token.weight,
123
+ self.face_token.weight,
124
+ self.cloth_token.weight,
125
+ self.rightear_token.weight,
126
+ self.leftear_token.weight,
127
+ self.rightbro_token.weight,
128
+ self.leftbro_token.weight,
129
+ self.righteye_token.weight,
130
+ self.lefteye_token.weight,
131
+ self.nose_token.weight,
132
+ self.innermouth_token.weight,
133
+ self.lowerlip_token.weight,
134
+ self.upperlip_token.weight,
135
+ self.hair_token.weight,
136
+ self.glass_token.weight,
137
+ self.hat_token.weight,
138
+ self.earring_token.weight,
139
+ self.necklace_token.weight,
140
+ ],
141
+ dim=0,
142
+ )
143
+
144
+ tokens = output_tokens.unsqueeze(0).expand(
145
+ image_embeddings.size(0), -1, -1
146
+ ) ##### torch.Size([4, 19, 256])
147
+
148
+ src = image_embeddings ##### torch.Size([4, 256, 128, 128])
149
+ pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
150
+ b, c, h, w = src.shape
151
+
152
+ # Run the transformer
153
+ hs, src = self.transformer(
154
+ src, pos_src, tokens
155
+ ) ####### hs - torch.Size([BS, 19, 256]), src - torch.Size([BS, 16384, 256])
156
+ mask_token_out = hs[:, :, :]
157
+
158
+ src = src.transpose(1, 2).view(b, c, h, w) ##### torch.Size([4, 256, 128, 128])
159
+ upscaled_embedding = self.output_upscaling(
160
+ src
161
+ ) ##### torch.Size([4, 32, 512, 512])
162
+ hyper_in = self.output_hypernetwork_mlps(
163
+ mask_token_out
164
+ ) ##### torch.Size([1, 19, 32])
165
+ b, c, h, w = upscaled_embedding.shape
166
+ seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
167
+ b, -1, h, w
168
+ ) ##### torch.Size([1, 19, 512, 512])
169
+
170
+ return seg_output
171
+
172
+
173
+ class PositionEmbeddingRandom(nn.Module):
174
+ """
175
+ Positional encoding using random spatial frequencies.
176
+ """
177
+
178
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
179
+ super().__init__()
180
+ if scale is None or scale <= 0.0:
181
+ scale = 1.0
182
+ self.register_buffer(
183
+ "positional_encoding_gaussian_matrix",
184
+ scale * torch.randn((2, num_pos_feats)),
185
+ )
186
+
187
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
188
+ """Positionally encode points that are normalized to [0,1]."""
189
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
190
+ coords = 2 * coords - 1
191
+ coords = coords @ self.positional_encoding_gaussian_matrix
192
+ coords = 2 * np.pi * coords
193
+ # outputs d_1 x ... x d_n x C shape
194
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
195
+
196
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
197
+ """Generate positional encoding for a grid of the specified size."""
198
+ h, w = size
199
+ device: Any = self.positional_encoding_gaussian_matrix.device
200
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
201
+ y_embed = grid.cumsum(dim=0) - 0.5
202
+ x_embed = grid.cumsum(dim=1) - 0.5
203
+ y_embed = y_embed / h
204
+ x_embed = x_embed / w
205
+
206
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
207
+ return pe.permute(2, 0, 1) # C x H x W
208
+
209
+ def forward_with_coords(
210
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
211
+ ) -> torch.Tensor:
212
+ """Positionally encode points that are not normalized to [0,1]."""
213
+ coords = coords_input.clone()
214
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
215
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
216
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
217
+
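Illustrative usage (a sketch, assuming the surrounding module is loaded): SegFaceCeleb below constructs this layer with num_pos_feats = 128, so the dense grid encoding has 2 * 128 = 256 channels and matches the fused feature map fed to FaceDecoder:

    pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
    pe = pe_layer((128, 128))      # (256, 128, 128): sin/cos of random Fourier projections
    image_pe = pe.unsqueeze(0)     # (1, 256, 128, 128), the image_pe expected by FaceDecoder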
218
+
219
+ class TwoWayTransformer(nn.Module):
220
+ def __init__(
221
+ self,
222
+ depth: int,
223
+ embedding_dim: int,
224
+ num_heads: int,
225
+ mlp_dim: int,
226
+ activation: Type[nn.Module] = nn.ReLU,
227
+ attention_downsample_rate: int = 2,
228
+ ) -> None:
229
+ """
230
+ A transformer decoder that attends to an input image using
231
+ queries whose positional embedding is supplied.
232
+
233
+ Args:
234
+ depth (int): number of layers in the transformer
235
+ embedding_dim (int): the channel dimension for the input embeddings
236
+ num_heads (int): the number of heads for multihead attention. Must
237
+ divide embedding_dim
238
+ mlp_dim (int): the channel dimension internal to the MLP block
239
+ activation (nn.Module): the activation to use in the MLP block
240
+ """
241
+ super().__init__()
242
+ self.depth = depth
243
+ self.embedding_dim = embedding_dim
244
+ self.num_heads = num_heads
245
+ self.mlp_dim = mlp_dim
246
+ self.layers = nn.ModuleList()
247
+
248
+ for i in range(depth):
249
+ self.layers.append(
250
+ TwoWayAttentionBlock(
251
+ embedding_dim=embedding_dim,
252
+ num_heads=num_heads,
253
+ mlp_dim=mlp_dim,
254
+ activation=activation,
255
+ attention_downsample_rate=attention_downsample_rate,
256
+ skip_first_layer_pe=(i == 0),
257
+ )
258
+ )
259
+
260
+ self.final_attn_token_to_image = Attention(
261
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
262
+ )
263
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
264
+
265
+ def forward(
266
+ self,
267
+ image_embedding: Tensor,
268
+ image_pe: Tensor,
269
+ point_embedding: Tensor,
270
+ ) -> Tuple[Tensor, Tensor]:
271
+ """
272
+ Args:
273
+ image_embedding (torch.Tensor): image to attend to. Should be shape
274
+ B x embedding_dim x h x w for any h and w.
275
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
276
+ have the same shape as image_embedding.
277
+ point_embedding (torch.Tensor): the embedding to add to the query points.
278
+ Must have shape B x N_points x embedding_dim for any N_points.
279
+
280
+ Returns:
281
+ torch.Tensor: the processed point_embedding
282
+ torch.Tensor: the processed image_embedding
283
+ """
284
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
285
+ bs, c, h, w = image_embedding.shape
286
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
287
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
288
+
289
+ # Prepare queries
290
+ queries = point_embedding
291
+ keys = image_embedding
292
+
293
+ # Apply transformer blocks and final layernorm
294
+ for layer in self.layers:
295
+ queries, keys = layer(
296
+ queries=queries,
297
+ keys=keys,
298
+ query_pe=point_embedding,
299
+ key_pe=image_pe,
300
+ )
301
+
302
+ # Apply the final attention layer from the points to the image
303
+ q = queries + point_embedding
304
+ k = keys + image_pe
305
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
306
+ queries = queries + attn_out
307
+ queries = self.norm_final_attn(queries)
308
+
309
+ return queries, keys
310
+
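A hedged shape check with toy tensors (not real model activations), using the configuration FaceDecoder passes in this file and assuming the classes defined below (TwoWayAttentionBlock, Attention) are already loaded:

    transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
    image_emb = torch.randn(1, 256, 128, 128)
    image_pe = torch.randn(1, 256, 128, 128)
    tokens = torch.randn(1, 19, 256)          # one query token per facial class
    queries, keys = transformer(image_emb, image_pe, tokens)
    # queries: (1, 19, 256), keys: (1, 16384, 256)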
311
+
312
+ class MLPBlock(nn.Module):
313
+ def __init__(
314
+ self,
315
+ embedding_dim: int,
316
+ mlp_dim: int,
317
+ act: Type[nn.Module] = nn.GELU,
318
+ ) -> None:
319
+ super().__init__()
320
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
321
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
322
+ self.act = act()
323
+
324
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
325
+ return self.lin2(self.act(self.lin1(x)))
326
+
327
+
328
+ class TwoWayAttentionBlock(nn.Module):
329
+ def __init__(
330
+ self,
331
+ embedding_dim: int,
332
+ num_heads: int,
333
+ mlp_dim: int = 2048,
334
+ activation: Type[nn.Module] = nn.ReLU,
335
+ attention_downsample_rate: int = 2,
336
+ skip_first_layer_pe: bool = False,
337
+ ) -> None:
338
+ """
339
+ A transformer block with four layers: (1) self-attention of sparse
340
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
341
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
342
+ inputs.
343
+
344
+ Arguments:
345
+ embedding_dim (int): the channel dimension of the embeddings
346
+ num_heads (int): the number of heads in the attention layers
347
+ mlp_dim (int): the hidden dimension of the mlp block
348
+ activation (nn.Module): the activation of the mlp block
349
+ skip_first_layer_pe (bool): skip the PE on the first layer
350
+ """
351
+ super().__init__()
352
+ self.self_attn = Attention(embedding_dim, num_heads)
353
+ self.norm1 = nn.LayerNorm(embedding_dim)
354
+
355
+ self.cross_attn_token_to_image = Attention(
356
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
357
+ )
358
+ self.norm2 = nn.LayerNorm(embedding_dim)
359
+
360
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
361
+ self.norm3 = nn.LayerNorm(embedding_dim)
362
+
363
+ self.norm4 = nn.LayerNorm(embedding_dim)
364
+ self.cross_attn_image_to_token = Attention(
365
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
366
+ )
367
+
368
+ self.skip_first_layer_pe = skip_first_layer_pe
369
+
370
+ def forward(
371
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
372
+ ) -> Tuple[Tensor, Tensor]:
373
+ # Self attention block
374
+ if self.skip_first_layer_pe:
375
+ queries = self.self_attn(q=queries, k=queries, v=queries)
376
+ else:
377
+ q = queries + query_pe
378
+ attn_out = self.self_attn(q=q, k=q, v=queries)
379
+ queries = queries + attn_out
380
+ queries = self.norm1(queries)
381
+
382
+ # Cross attention block, tokens attending to image embedding
383
+ q = queries + query_pe
384
+ k = keys + key_pe
385
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
386
+ queries = queries + attn_out
387
+ queries = self.norm2(queries)
388
+
389
+ # MLP block
390
+ mlp_out = self.mlp(queries)
391
+ queries = queries + mlp_out
392
+ queries = self.norm3(queries)
393
+
394
+ # Cross attention block, image embedding attending to tokens
395
+ q = queries + query_pe
396
+ k = keys + key_pe
397
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
398
+ keys = keys + attn_out
399
+ keys = self.norm4(keys)
400
+
401
+ return queries, keys
402
+
403
+
404
+ class Attention(nn.Module):
405
+ """
406
+ An attention layer that allows for downscaling the size of the embedding
407
+ after projection to queries, keys, and values.
408
+ """
409
+
410
+ def __init__(
411
+ self,
412
+ embedding_dim: int,
413
+ num_heads: int,
414
+ downsample_rate: int = 1,
415
+ ) -> None:
416
+ super().__init__()
417
+ self.embedding_dim = embedding_dim
418
+ self.internal_dim = embedding_dim // downsample_rate
419
+ self.num_heads = num_heads
420
+ assert (
421
+ self.internal_dim % num_heads == 0
422
+ ), "num_heads must divide embedding_dim."
423
+
424
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
425
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
426
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
427
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
428
+
429
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
430
+ b, n, c = x.shape
431
+ x = x.reshape(b, n, num_heads, c // num_heads)
432
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
433
+
434
+ def _recombine_heads(self, x: Tensor) -> Tensor:
435
+ b, n_heads, n_tokens, c_per_head = x.shape
436
+ x = x.transpose(1, 2)
437
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
438
+
439
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
440
+ # Input projections
441
+ q = self.q_proj(q)
442
+ k = self.k_proj(k)
443
+ v = self.v_proj(v)
444
+
445
+ # Separate into heads
446
+ q = self._separate_heads(q, self.num_heads)
447
+ k = self._separate_heads(k, self.num_heads)
448
+ v = self._separate_heads(v, self.num_heads)
449
+
450
+ # Attention
451
+ _, _, _, c_per_head = q.shape
452
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
453
+ attn = attn / math.sqrt(c_per_head)
454
+ attn = torch.softmax(attn, dim=-1)
455
+
456
+ # Get output
457
+ out = attn @ v
458
+ out = self._recombine_heads(out)
459
+ out = self.out_proj(out)
460
+
461
+ return out
462
+
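For reference, a small illustrative example with toy shapes: with downsample_rate=2 the q/k/v projections run at embedding_dim // 2 channels internally and out_proj maps back to embedding_dim:

    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 19, 256)               # query tokens
    kv = torch.randn(1, 16384, 256)           # flattened image tokens
    out = attn(q=q, k=kv, v=kv)               # (1, 19, 256)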
463
+
464
+ class SegfaceMLP(nn.Module):
465
+ """
466
+ Linear Embedding.
467
+ """
468
+
469
+ def __init__(self, input_dim):
470
+ super().__init__()
471
+ self.proj = nn.Linear(input_dim, 256)
472
+
473
+ def forward(self, hidden_states: torch.Tensor):
474
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
475
+ hidden_states = self.proj(hidden_states)
476
+ return hidden_states
477
+
478
+
479
+ class SegFaceCeleb(nn.Module):
480
+ def __init__(self, input_resolution, model):
481
+ super(SegFaceCeleb, self).__init__()
482
+ self.input_resolution = input_resolution
483
+ self.model = model
484
+
485
+ if self.model == "convnext_base":
486
+ convnext = convnext_base(pretrained=False)
487
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
488
+ self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
489
+ self.multi_scale_features = []
490
+
491
+ if self.model == "convnext_small":
492
+ convnext = convnext_small(pretrained=False)
493
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
494
+ self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
495
+ self.multi_scale_features = []
496
+
497
+ if self.model == "convnext_tiny":
498
+ convnext = convnext_small(pretrained=False)
499
+ self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
500
+ self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
501
+ self.multi_scale_features = []
502
+
503
+ embed_dim = 1024
504
+ out_chans = 256
505
+
506
+ self.pe_layer = PositionEmbeddingRandom(out_chans // 2)
507
+
508
+ for name, module in self.backbone.named_modules():
509
+ if name in self.target_layer_names:
510
+ module.register_forward_hook(self.save_features_hook(name))
511
+
512
+ self.face_decoder = FaceDecoder(
513
+ transformer_dim=256,
514
+ transformer=TwoWayTransformer(
515
+ depth=2,
516
+ embedding_dim=256,
517
+ mlp_dim=2048,
518
+ num_heads=8,
519
+ ),
520
+ )
521
+
522
+ num_encoder_blocks = 4
523
+ if self.model in ["swin_base", "swinv2_base", "convnext_base"]:
524
+ hidden_sizes = [128, 256, 512, 1024] ### Swin Base and ConvNext Base
525
+ if self.model in ["resnet"]:
526
+ hidden_sizes = [256, 512, 1024, 2048] ### ResNet
527
+ if self.model in [
528
+ "swinv2_small",
529
+ "swinv2_tiny",
530
+ "convnext_small",
531
+ "convnext_tiny",
532
+ ]:
533
+ hidden_sizes = [
534
+ 96,
535
+ 192,
536
+ 384,
537
+ 768,
538
+ ] ### Swin Small/Tiny and ConvNext Small/Tiny
539
+ if self.model in ["mobilenet"]:
540
+ hidden_sizes = [24, 40, 112, 960] ### MobileNet
541
+ if self.model in ["efficientnet"]:
542
+ hidden_sizes = [48, 80, 176, 1280] ### EfficientNet
543
+ decoder_hidden_size = 256
544
+
545
+ mlps = []
546
+ for i in range(num_encoder_blocks):
547
+ mlp = SegfaceMLP(input_dim=hidden_sizes[i])
548
+ mlps.append(mlp)
549
+ self.linear_c = nn.ModuleList(mlps)
550
+
551
+ # The following 3 layers implement the ConvModule of the original implementation
552
+ self.linear_fuse = nn.Conv2d(
553
+ in_channels=decoder_hidden_size * num_encoder_blocks,
554
+ out_channels=decoder_hidden_size,
555
+ kernel_size=1,
556
+ bias=False,
557
+ )
558
+
559
+ def save_features_hook(self, name):
560
+ def hook(module, input, output):
561
+ if self.model in [
562
+ "swin_base",
563
+ "swinv2_base",
564
+ "swinv2_small",
565
+ "swinv2_tiny",
566
+ ]:
567
+ self.multi_scale_features.append(
568
+ output.permute(0, 3, 1, 2).contiguous()
569
+ ) ### Swin, Swinv2
570
+ if self.model in [
571
+ "convnext_base",
572
+ "convnext_small",
573
+ "convnext_tiny",
574
+ "mobilenet",
575
+ "efficientnet",
576
+ ]:
577
+ self.multi_scale_features.append(
578
+ output
579
+ ) ### ConvNext, ResNet, EfficientNet, MobileNet
580
+
581
+ return hook
582
+
583
+ def forward(self, x):
584
+ self.multi_scale_features.clear()
585
+
586
+ _, _, h, w = x.shape
587
+ features = self.backbone(x).squeeze()
588
+
589
+ batch_size = self.multi_scale_features[-1].shape[0]
590
+ all_hidden_states = ()
591
+ for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
592
+ height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
593
+ encoder_hidden_state = mlp(encoder_hidden_state)
594
+ encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
595
+ encoder_hidden_state = encoder_hidden_state.reshape(
596
+ batch_size, -1, height, width
597
+ )
598
+ # upsample
599
+ encoder_hidden_state = nn.functional.interpolate(
600
+ encoder_hidden_state,
601
+ size=self.multi_scale_features[0].size()[2:],
602
+ mode="bilinear",
603
+ align_corners=False,
604
+ )
605
+ all_hidden_states += (encoder_hidden_state,)
606
+
607
+ fused_states = self.linear_fuse(
608
+ torch.cat(all_hidden_states[::-1], dim=1)
609
+ ) #### torch.Size([BS, 256, 128, 128])
610
+ image_pe = self.pe_layer(
611
+ (fused_states.shape[2], fused_states.shape[3])
612
+ ).unsqueeze(0)
613
+ seg_output = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
614
+
615
+ return seg_output
616
+
617
+
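A hedged end-to-end shape sketch for the parser defined above (random weights here; the pretrained checkpoint is loaded in ImageGenerator below):

    segface = SegFaceCeleb(512, "convnext_base").eval()
    x = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        logits = segface(x)                   # (1, 19, 512, 512) per-class logits
    labels = logits.argmax(dim=1)             # (1, 512, 512) per-pixel class indices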
618
+ # Wrapper class for model and configuration initialization
619
+ class ImageGenerator:
620
+ def __init__(self):
621
+ self.args = self.get_args()
622
+ self.pipeline, self.moe_model = self.get_model(self.args)
623
+ with open(self.args.config_path, "r") as f:
624
+ self.model_config = yaml.safe_load(f)["model"]
625
+ self.farl = facer.face_parser(
626
+ "farl/celebm/448",
627
+ self.args.device,
628
+ model_path="https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt",
629
+ )
630
+ self.segface = SegFaceCeleb(512, "convnext_base").to(self.args.device)
631
+ checkpoint = torch.hub.load_state_dict_from_url("https://huggingface.co/kartiknarayan/SegFace/resolve/main/convnext_celeba_512/model_299.pt")
632
+ self.segface.load_state_dict(checkpoint["state_dict_backbone"])
633
+ self.segface.eval()
634
+ self.segface_transforms = torchvision.transforms.Compose(
635
+ [
636
+ torchvision.transforms.ToTensor(),
637
+ torchvision.transforms.Normalize(
638
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
639
+ ),
640
+ ]
641
+ )
642
+
643
+ self.seg_face_remap_dict = {
644
+ 0: 0, 1: 17, 2: 1, 3: 18, 4: 9, 5: 8, 6: 7, 7: 6,
645
+ 8: 5, 9: 4, 10: 2, 11: 10, 12: 12, 13: 11, 14: 13,
646
+ 15: 3, 16: 14, 17: 15, 18: 16,
647
+ }
648
+
649
+ self.palette = np.array(
650
+ [
651
+ (0, 0, 0), (204, 0, 0), (76, 153, 0), (204, 204, 0),
652
+ (204, 0, 204), (51, 51, 255), (255, 204, 204), (0, 255, 255),
653
+ (255, 0, 0), (102, 51, 0), (102, 204, 0), (255, 255, 0),
654
+ (0, 0, 153), (0, 0, 204), (255, 51, 153), (0, 204, 204),
655
+ (0, 51, 0), (255, 153, 51), (0, 204, 0),
656
+ ],
657
+ dtype=np.uint8,
658
+ )
659
+
660
+ self.org_labels = [
661
+ "background", "face", "nose", "eyeg", "le", "re", "lb", "rb",
662
+ "lr", "rr", "imouth", "ulip", "llip", "hair", "hat", "earr",
663
+ "neck_l", "neck", "cloth",
664
+ ]
665
+
666
+ self.new_labels = [
667
+ "background", "neck", "face", "cloth", "rr", "lr", "rb", "lb",
668
+ "re", "le", "nose", "imouth", "llip", "ulip", "hair", "eyeg",
669
+ "hat", "earr", "neck_l",
670
+ ]
671
+
672
+ @torch.no_grad()
673
+ def parse_face_with_farl(self, image):
674
+ image = image.resize((512, 512), Image.BICUBIC)
675
+ image_np = np.array(image)
676
+ image_pt = torch.tensor(image_np).to(self.args.device)
677
+ image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).float()
678
+ pred, _ = self.farl.net(image_pt / 255.0)
679
+ vis_seg_probs = pred.argmax(dim=1).detach().cpu().numpy()[0].astype(np.uint8)
680
+ remapped_mask = np.zeros_like(vis_seg_probs, dtype=np.uint8)
681
+ for i, pred_label in enumerate(self.new_labels):
682
+ if pred_label in self.org_labels:
683
+ remapped_mask[vis_seg_probs == i] = self.org_labels.index(pred_label)
684
+ vis_seg_probs = Image.fromarray(remapped_mask).convert("P")
685
+ vis_seg_probs.putpalette(self.palette.flatten())
686
+ return vis_seg_probs
687
+
688
+ @torch.no_grad()
689
+ def parse_face_with_segface(self, image):
690
+ image = image.resize((512, 512), Image.BICUBIC)
691
+ image = self.segface_transforms(image)
692
+ logits = self.segface(image.unsqueeze(0).to(self.args.device))
693
+ vis_seg_probs = logits.argmax(dim=1).detach().cpu().numpy()[0].astype(np.uint8)
694
+ new_mask = np.zeros_like(vis_seg_probs)
695
+ for old_idx, new_idx in self.seg_face_remap_dict.items():
696
+ new_mask[vis_seg_probs == old_idx] = new_idx
697
+ vis_seg_probs = Image.fromarray(new_mask).convert("P")
698
+ vis_seg_probs.putpalette(self.palette.flatten())
699
+ return vis_seg_probs
700
+
701
+ def get_args(self):
702
+ class Args:
703
+ pipe = "black-forest-labs/FLUX.1-dev"
704
+ lora_ckpt = "weights"
705
+ moe_ckpt = "weights/mogle.pt"
706
+ pretrained_ckpt = "weights/FLUX.1-dev"
707
+ device = "cuda" if torch.cuda.is_available() else "cpu"
708
+ size = 512
709
+ seed = 42
710
+ config_path = "config/Face-MoGLE.yaml"
711
+ return Args()
712
+
713
+ def get_model(self, args):
714
+ pipeline = FluxPipeline.from_pretrained(
715
+ args.pretrained_ckpt, torch_dtype=torch.bfloat16
716
+ )
717
+ pipeline.load_lora_weights(args.lora_ckpt, weight_name="pytorch_lora_weights.safetensors")
718
+ pipeline.to(args.device)
719
+ moe_model = MoGLE()
720
+ moe_weight = torch.load(args.moe_ckpt, map_location="cpu")
721
+ moe_model.load_state_dict(moe_weight, strict=True)
722
+ moe_model = moe_model.to(device=args.device, dtype=torch.bfloat16)
723
+ moe_model.eval()
724
+ return pipeline, moe_model
725
+
726
+ def pack_data(self, mask_image: Image.Image):
727
+ mask = np.array(mask_image.convert("L"))
728
+ mask_list = [T.ToTensor()(mask_image.convert("RGB"))]
729
+ for i in range(19):
730
+ local_mask = np.zeros_like(mask)
731
+ local_mask[mask == i] = 255
732
+ local_mask_tensor = T.ToTensor()(Image.fromarray(local_mask).convert("RGB"))
733
+ mask_list.append(local_mask_tensor)
734
+ condition_img = torch.stack(mask_list, dim=0)
735
+ return Condition(condition_type="depth", condition=condition_img, position_delta=[0, 0])
736
+
737
+ def generate(self, prompt: str, mask_image: Image.Image, seed: int, num_inference_steps=28):
738
+ generator = torch.Generator().manual_seed(seed)
739
+ condition = self.pack_data(mask_image)
740
+ result = generate(
741
+ self.pipeline,
742
+ mogle=self.moe_model,
743
+ prompt=prompt,
744
+ conditions=[condition],
745
+ height=self.args.size,
746
+ width=self.args.size,
747
+ generator=generator,
748
+ model_config=self.model_config,
749
+ default_lora=True,
750
+ num_inference_steps=num_inference_steps
751
+ )
752
+ return result.images[0]
753
+
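A minimal sketch of driving ImageGenerator outside the Gradio UI, assuming the weights/, config/, and assets/ folders from this upload are available locally:

    gen = ImageGenerator()
    mask = Image.open("assets/multimodal/liuyifei_seg.png")
    image = gen.generate("She has red hair", mask, seed=42, num_inference_steps=28)
    image.save("result.png")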
754
+
755
+ # Instantiate the generator
756
+ generator = ImageGenerator()
757
+
758
+ examples = [
759
+
760
+ ["", "assets/mask2face/handou_seg.png", None, "FaRL", 42, 28],
761
+
762
+ ["", "assets/mask2face/black_seg.png", None, "FaRL", 42, 28],
763
+
764
+ ["She has red hair", "assets/multimodal/liuyifei_seg.png", None, "FaRL", 42, 28],
765
+
766
+ ["He is old", "assets/multimodal/musk_seg.png", None, "FaRL", 42, 28],
767
+
768
+ ["Curly-haired woman with glasses", None, None, "FaRL", 42, 28],
769
+
770
+ ["Man with beard and tie", None, None, "FaRL", 42, 28],
771
+
772
+ ]
773
+
774
+ # Gradio interface (using Blocks)
775
+ with gr.Blocks(title="Controllable Face Generation with MoGLE") as demo:
776
+ gr.Markdown("## 🎭 Controllable Face Generation via Prompt + Face Parsing")
777
+
778
+ with gr.Row():
779
+ prompt = gr.Textbox(label="Text Prompt", placeholder="Describe the face you'd like to generate...")
780
+
781
+ with gr.Row():
782
+ with gr.Column():
783
+ mask_image = gr.Image(type="pil", label="🧩 Semantic Mask (Optional)")
784
+ rgb_image = gr.Image(type="pil", label="🖼️ Facial Image (Optional)")
785
+ model_choice = gr.Radio(["FaRL", "SegFace"], label="Face Parsing Model", value="FaRL")
786
+ seed = gr.Slider(minimum=0, maximum=100000, step=1, value=42, label="Random Seed")
787
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Sampling Steps")
788
+ submit_btn = gr.Button("Generate")
789
+
790
+ with gr.Column():
791
+ gr.Markdown("### 🧠 Parsed Mask Preview")
792
+ preview_mask = gr.Image(label="Parsed Mask (from RGB)", interactive=False)
793
+ output_image = gr.Image(label="🎨 Generated Image")
794
+
795
+ def generate_wrapper(prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps):
796
+ if mask_image is None and rgb_image is not None:
797
+ if model_choice == "FaRL":
798
+ mask_image = generator.parse_face_with_farl(rgb_image)
799
+ else:
800
+ mask_image = generator.parse_face_with_segface(rgb_image)
801
+ elif mask_image is None and rgb_image is None:
802
+ # raise gr.Error("Please upload at least one input: a semantic segmentation mask or an RGB face image.")
803
+ mask_image = Image.new("RGB", size=(512, 512))
804
+ return mask_image, generator.generate(prompt, mask_image, seed, num_inference_steps)
805
+
806
+ submit_btn.click(
807
+ fn=generate_wrapper,
808
+ inputs=[prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps],
809
+ outputs=[preview_mask, output_image]
810
+ )
811
+ gr.Examples(
812
+ examples=examples,
813
+ inputs=[prompt, mask_image, rgb_image, model_choice, seed, num_inference_steps],
814
+ outputs=[preview_mask, output_image],
815
+ fn=generate_wrapper, # reuse the wrapper defined above
816
+ cache_examples=False,
817
+ label="Click any example below to try:"
818
+ )
819
+
820
+ if __name__ == "__main__":
821
+ demo.launch(server_name="0.0.0.0", server_port=5000, share=False)
assets/mask2face/black_seg.png ADDED
assets/mask2face/handou_seg.png ADDED
assets/multimodal/liuyifei_seg.png ADDED
assets/multimodal/musk_seg.png ADDED
config/Face-MoGLE.yaml ADDED
@@ -0,0 +1,40 @@
1
+ # flux_path: "black-forest-labs/FLUX.1-dev"
2
+ sd_path: "checkpoints/FLUX.1-dev"
3
+ dtype: "bfloat16"
4
+
5
+ model:
6
+ union_cond_attn: true
7
+ add_cond_attn: false
8
+ latent_lora: false
9
+
10
+ train:
11
+ batch_size: 4
12
+ accumulate_grad_batches: 1
13
+ dataloader_workers: 4
14
+ save_interval: 1000
15
+ sample_interval: 100
16
+ max_steps: -1
17
+ gradient_checkpointing: true
18
+ save_path: "runs/face-mogle"
19
+
20
+ condition_type: "depth"
21
+ dataset:
22
+ root: "data/mmcelebahq"
23
+ condition_size: 512
24
+ target_size: 512
25
+ drop_text_prob: 0.1
26
+ drop_image_prob: 0.1
27
+
28
+ lora_config:
29
+ r: 4
30
+ lora_alpha: 4
31
+ init_lora_weights: "gaussian"
32
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
33
+
34
+ optimizer:
35
+ type: "Prodigy"
36
+ params:
37
+ lr: 1
38
+ use_bias_correction: true
39
+ safeguard_warmup: true
40
+ weight_decay: 0.01
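For reference, a minimal sketch (not part of the upload) of how app.py consumes this file: only the model: section is read, via yaml.safe_load:

    import yaml

    with open("config/Face-MoGLE.yaml") as f:
        config = yaml.safe_load(f)
    model_config = config["model"]
    # {'union_cond_attn': True, 'add_cond_attn': False, 'latent_lora': False}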
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ diffusers==0.31.0
2
+ transformers
3
+ peft
4
+ opencv-python
5
+ protobuf
6
+ sentencepiece
7
+ gradio
8
+ jupyter
9
+ torchao
10
+ pyfacer
11
+ pyyaml
12
+
13
+ lightning
14
+ datasets
15
+ torchvision
16
+ prodigyopt
17
+ wandb
src/flux/__pycache__/block.cpython-311.pyc ADDED
Binary file (14.6 kB).
 
src/flux/__pycache__/condition.cpython-311.pyc ADDED
Binary file (5.74 kB).
 
src/flux/__pycache__/generate.cpython-311.pyc ADDED
Binary file (11.9 kB).
 
src/flux/__pycache__/lora_controller.cpython-311.pyc ADDED
Binary file (5.12 kB).
 
src/flux/__pycache__/pipeline_tools.cpython-311.pyc ADDED
Binary file (2.56 kB).
 
src/flux/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (7.85 kB).
 
src/flux/block.py ADDED
@@ -0,0 +1,345 @@
1
+ import torch
2
+ from typing import List, Union, Optional, Dict, Any, Callable
3
+ from diffusers.models.attention_processor import Attention, F
4
+ from .lora_controller import enable_lora
5
+
6
+
7
+ def attn_forward(
8
+ attn: Attention,
9
+ hidden_states: torch.FloatTensor,
10
+ encoder_hidden_states: torch.FloatTensor = None,
11
+ condition_latents: torch.FloatTensor = None,
12
+ attention_mask: Optional[torch.FloatTensor] = None,
13
+ image_rotary_emb: Optional[torch.Tensor] = None,
14
+ cond_rotary_emb: Optional[torch.Tensor] = None,
15
+ model_config: Optional[Dict[str, Any]] = {},
16
+ ) -> torch.FloatTensor:
17
+ batch_size, _, _ = (
18
+ hidden_states.shape
19
+ if encoder_hidden_states is None
20
+ else encoder_hidden_states.shape
21
+ )
22
+
23
+ with enable_lora(
24
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25
+ ):
26
+ # `sample` projections.
27
+ query = attn.to_q(hidden_states)
28
+ key = attn.to_k(hidden_states)
29
+ value = attn.to_v(hidden_states)
30
+ # print(query.shape,key.shape,value.shape) torch.Size([2, 1024, 3072]) torch.Size([2, 1024, 3072]) torch.Size([2, 1024, 3072])
31
+
32
+ inner_dim = key.shape[-1]
33
+ head_dim = inner_dim // attn.heads
34
+
35
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
38
+
39
+ if attn.norm_q is not None:
40
+ query = attn.norm_q(query)
41
+ if attn.norm_k is not None:
42
+ key = attn.norm_k(key)
43
+
44
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
45
+ if encoder_hidden_states is not None:
46
+ # print(hidden_states.shape,encoder_hidden_states.shape,condition_latents.shape) torch.Size([2, 1024, 3072]) torch.Size([2, 512, 3072]) torch.Size([2, 1024, 3072])
47
+ # `context` projections.
48
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
49
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
50
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
51
+
52
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
53
+ batch_size, -1, attn.heads, head_dim
54
+ ).transpose(1, 2)
55
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
56
+ batch_size, -1, attn.heads, head_dim
57
+ ).transpose(1, 2)
58
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
59
+ batch_size, -1, attn.heads, head_dim
60
+ ).transpose(1, 2)
61
+
62
+ if attn.norm_added_q is not None:
63
+ encoder_hidden_states_query_proj = attn.norm_added_q(
64
+ encoder_hidden_states_query_proj
65
+ )
66
+ if attn.norm_added_k is not None:
67
+ encoder_hidden_states_key_proj = attn.norm_added_k(
68
+ encoder_hidden_states_key_proj
69
+ )
70
+ # attention
71
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
72
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
73
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
74
+
75
+ if image_rotary_emb is not None:
76
+ from diffusers.models.embeddings import apply_rotary_emb
77
+
78
+ query = apply_rotary_emb(query, image_rotary_emb)
79
+ key = apply_rotary_emb(key, image_rotary_emb)
80
+
81
+ if condition_latents is not None:
82
+ cond_query = attn.to_q(condition_latents)
83
+ cond_key = attn.to_k(condition_latents)
84
+ cond_value = attn.to_v(condition_latents)
85
+
86
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
87
+ 1, 2
88
+ )
89
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
90
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
91
+ 1, 2
92
+ )
93
+ if attn.norm_q is not None:
94
+ cond_query = attn.norm_q(cond_query)
95
+ if attn.norm_k is not None:
96
+ cond_key = attn.norm_k(cond_key)
97
+
98
+ if cond_rotary_emb is not None:
99
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
100
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
101
+
102
+ if condition_latents is not None:
103
+ query = torch.cat([query, cond_query], dim=2)
104
+ key = torch.cat([key, cond_key], dim=2)
105
+ value = torch.cat([value, cond_value], dim=2)
106
+
107
+ if not model_config.get("union_cond_attn", True):
108
+ # If we don't want to use the union condition attention, we need to mask the attention
109
+ # between the hidden states and the condition latents
110
+ attention_mask = torch.ones(
111
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
112
+ )
113
+ condition_n = cond_query.shape[2]
114
+ attention_mask[-condition_n:, :-condition_n] = False
115
+ attention_mask[:-condition_n, -condition_n:] = False
116
+ elif model_config.get("independent_condition", False):
117
+ attention_mask = torch.ones(
118
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
119
+ )
120
+ condition_n = cond_query.shape[2]
121
+ attention_mask[-condition_n:, :-condition_n] = False
122
+ if hasattr(attn, "c_factor"):
123
+ attention_mask = torch.zeros(
124
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
125
+ )
126
+ condition_n = cond_query.shape[2]
127
+ bias = torch.log(attn.c_factor[0])
128
+ attention_mask[-condition_n:, :-condition_n] = bias
129
+ attention_mask[:-condition_n, -condition_n:] = bias
130
+ hidden_states = F.scaled_dot_product_attention(
131
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
132
+ )
133
+ hidden_states = hidden_states.transpose(1, 2).reshape(
134
+ batch_size, -1, attn.heads * head_dim
135
+ )
136
+ # print(f"hidden_states {hidden_states.shape}")
137
+ hidden_states = hidden_states.to(query.dtype)
138
+
139
+ if encoder_hidden_states is not None:
140
+ if condition_latents is not None:
141
+ encoder_hidden_states, hidden_states, condition_latents = (
142
+ hidden_states[:, : encoder_hidden_states.shape[1]],
143
+ hidden_states[
144
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
145
+ ],
146
+ hidden_states[:, -condition_latents.shape[1] :],
147
+ )
148
+ else:
149
+ encoder_hidden_states, hidden_states = (
150
+ hidden_states[:, : encoder_hidden_states.shape[1]],
151
+ hidden_states[:, encoder_hidden_states.shape[1] :],
152
+ )
153
+
154
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
155
+ # linear proj
156
+ hidden_states = attn.to_out[0](hidden_states)
157
+ # dropout
158
+ hidden_states = attn.to_out[1](hidden_states)
159
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
160
+
161
+ if condition_latents is not None:
162
+ condition_latents = attn.to_out[0](condition_latents)
163
+ condition_latents = attn.to_out[1](condition_latents)
164
+
165
+ return (
166
+ (hidden_states, encoder_hidden_states, condition_latents)
167
+ if condition_latents is not None
168
+ else (hidden_states, encoder_hidden_states)
169
+ )
170
+ elif condition_latents is not None:
171
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
172
+ hidden_states, condition_latents = (
173
+ hidden_states[:, : -condition_latents.shape[1]],
174
+ hidden_states[:, -condition_latents.shape[1] :],
175
+ )
176
+ # print(hidden_states.shape,condition_latents.shape) torch.Size([2, 1536, 3072]) torch.Size([2, 1024, 3072])
177
+ return hidden_states, condition_latents
178
+ else:
179
+ return hidden_states
180
+
181
+
182
+ def block_forward(
183
+ self,
184
+ hidden_states: torch.FloatTensor,
185
+ encoder_hidden_states: torch.FloatTensor,
186
+ condition_latents: torch.FloatTensor,
187
+ temb: torch.FloatTensor,
188
+ cond_temb: torch.FloatTensor,
189
+ cond_rotary_emb=None,
190
+ image_rotary_emb=None,
191
+ model_config: Optional[Dict[str, Any]] = {},
192
+ ):
193
+ use_cond = condition_latents is not None
194
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, emb=temb
197
+ )
198
+
199
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
200
+ self.norm1_context(encoder_hidden_states, emb=temb)
201
+ )
202
+ # print(norm_encoder_hidden_states.shape,c_gate_msa.shape,c_shift_mlp.shape,c_scale_mlp.shape,c_gate_mlp.shape)
203
+ # torch.Size([2, 512, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072]) torch.Size([2, 3072])
204
+
205
+ if use_cond:
206
+ (
207
+ norm_condition_latents,
208
+ cond_gate_msa,
209
+ cond_shift_mlp,
210
+ cond_scale_mlp,
211
+ cond_gate_mlp,
212
+ ) = self.norm1(condition_latents, emb=cond_temb)
213
+
214
+ # Attention.
215
+ result = attn_forward(
216
+ self.attn,
217
+ model_config=model_config,
218
+ hidden_states=norm_hidden_states,
219
+ encoder_hidden_states=norm_encoder_hidden_states,
220
+ condition_latents=norm_condition_latents if use_cond else None,
221
+ image_rotary_emb=image_rotary_emb,
222
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
223
+ )
224
+ attn_output, context_attn_output = result[:2]
225
+ cond_attn_output = result[2] if use_cond else None
226
+
227
+ # Process attention outputs for the `hidden_states`.
228
+ # 1. hidden_states
229
+ attn_output = gate_msa.unsqueeze(1) * attn_output
230
+ # print(hidden_states.shape,attn_output.shape) torch.Size([2, 1024, 3072]) torch.Size([2, 1024, 3072])
231
+ hidden_states = hidden_states + attn_output
232
+ # 2. encoder_hidden_states
233
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
234
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
235
+ # 3. condition_latents
236
+ if use_cond:
237
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
238
+ condition_latents = condition_latents + cond_attn_output
239
+ if model_config.get("add_cond_attn", False):
240
+ hidden_states += cond_attn_output
241
+
242
+ # LayerNorm + MLP.
243
+ # 1. hidden_states
244
+ norm_hidden_states = self.norm2(hidden_states)
245
+ norm_hidden_states = (
246
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
247
+ )
248
+ # 2. encoder_hidden_states
249
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
250
+ norm_encoder_hidden_states = (
251
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
252
+ )
253
+ # 3. condition_latents
254
+ if use_cond:
255
+ norm_condition_latents = self.norm2(condition_latents)
256
+ norm_condition_latents = (
257
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
258
+ + cond_shift_mlp[:, None]
259
+ )
260
+
261
+ # Feed-forward.
262
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
263
+ # 1. hidden_states
264
+ ff_output = self.ff(norm_hidden_states)
265
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
266
+ # 2. encoder_hidden_states
267
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
268
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
269
+ # 3. condition_latents
270
+ if use_cond:
271
+ cond_ff_output = self.ff(norm_condition_latents)
272
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
273
+
274
+ # Process feed-forward outputs.
275
+ hidden_states = hidden_states + ff_output
276
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
277
+ if use_cond:
278
+ condition_latents = condition_latents + cond_ff_output
279
+
280
+ # Clip to avoid overflow.
281
+ if encoder_hidden_states.dtype == torch.float16:
282
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
283
+
284
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
285
+
286
+
287
+ def single_block_forward(
288
+ self,
289
+ hidden_states: torch.FloatTensor,
290
+ temb: torch.FloatTensor,
291
+ image_rotary_emb=None,
292
+ condition_latents: torch.FloatTensor = None,
293
+ cond_temb: torch.FloatTensor = None,
294
+ cond_rotary_emb=None,
295
+ model_config: Optional[Dict[str, Any]] = {},
296
+ ):
297
+
298
+ using_cond = condition_latents is not None
299
+ residual = hidden_states
300
+ with enable_lora(
301
+ (
302
+ self.norm.linear,
303
+ self.proj_mlp,
304
+ ),
305
+ model_config.get("latent_lora", False),
306
+ ):
307
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
308
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
309
+ if using_cond:
310
+ residual_cond = condition_latents
311
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
312
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
313
+
314
+ attn_output = attn_forward(
315
+ self.attn,
316
+ model_config=model_config,
317
+ hidden_states=norm_hidden_states,
318
+ image_rotary_emb=image_rotary_emb,
319
+ **(
320
+ {
321
+ "condition_latents": norm_condition_latents,
322
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
323
+ }
324
+ if using_cond
325
+ else {}
326
+ ),
327
+ )
328
+ if using_cond:
329
+ attn_output, cond_attn_output = attn_output
330
+
331
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
332
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
333
+ gate = gate.unsqueeze(1)
334
+ hidden_states = gate * self.proj_out(hidden_states)
335
+ hidden_states = residual + hidden_states
336
+ if using_cond:
337
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
338
+ cond_gate = cond_gate.unsqueeze(1)
339
+ condition_latents = cond_gate * self.proj_out(condition_latents)
340
+ condition_latents = residual_cond + condition_latents
341
+
342
+ if hidden_states.dtype == torch.float16:
343
+ hidden_states = hidden_states.clip(-65504, 65504)
344
+
345
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
src/flux/condition.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ from typing import Optional, Union, List, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from .pipeline_tools import encode_images
9
+
10
+ condition_dict = {
11
+ "depth": 0,
12
+ "canny": 1,
13
+ "subject": 4,
14
+ "coloring": 6,
15
+ "deblurring": 7,
16
+ "depth_pred": 8,
17
+ "fill": 9,
18
+ "sr": 10,
19
+ "cartoon": 11,
20
+ }
21
+
22
+
23
+ class Condition(object):
24
+ def __init__(
25
+ self,
26
+ condition_type: str,
27
+ raw_img: Union[Image.Image, torch.Tensor] = None,
28
+ condition: Union[Image.Image, torch.Tensor] = None,
29
+ mask=None,
30
+ position_delta=None,
31
+ position_scale=1.0,
32
+ ) -> None:
33
+ self.condition_type = condition_type
34
+ assert raw_img is not None or condition is not None
35
+ if raw_img is not None:
36
+ self.condition = self.get_condition(condition_type, raw_img)
37
+ else:
38
+ self.condition = condition
39
+ self.position_delta = position_delta
40
+ self.position_scale = position_scale
41
+ # TODO: Add mask support
42
+ assert mask is None, "Mask not supported yet"
43
+
44
+ def get_condition(
45
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
46
+ ) -> Union[Image.Image, torch.Tensor]:
47
+ """
48
+ Returns the condition image.
49
+ """
50
+ if condition_type == "depth":
51
+ from transformers import pipeline
52
+
53
+ depth_pipe = pipeline(
54
+ task="depth-estimation",
55
+ model="LiheYoung/depth-anything-small-hf",
56
+ device="cuda",
57
+ )
58
+ source_image = raw_img.convert("RGB")
59
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
60
+ return condition_img
61
+ elif condition_type == "canny":
62
+ img = np.array(raw_img)
63
+ edges = cv2.Canny(img, 100, 200)
64
+ edges = Image.fromarray(edges).convert("RGB")
65
+ return edges
66
+ elif condition_type == "subject":
67
+ return raw_img
68
+ elif condition_type == "coloring":
69
+ return raw_img.convert("L").convert("RGB")
70
+ elif condition_type == "deblurring":
71
+ condition_image = (
72
+ raw_img.convert("RGB")
73
+ .filter(ImageFilter.GaussianBlur(10))
74
+ .convert("RGB")
75
+ )
76
+ return condition_image
77
+ elif condition_type == "fill":
78
+ return raw_img.convert("RGB")
79
+ elif condition_type == "cartoon":
80
+ return raw_img.convert("RGB")
81
+ raise NotImplementedError(f"Condition type {condition_type} not implemented")
82
+
83
+ @property
84
+ def type_id(self) -> int:
85
+ """
86
+ Returns the type id of the condition.
87
+ """
88
+ return condition_dict[self.condition_type]
89
+
90
+ @classmethod
91
+ def get_type_id(cls, condition_type: str) -> int:
92
+ """
93
+ Returns the type id of the condition.
94
+ """
95
+ return condition_dict[condition_type]
96
+
97
+ def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
98
+ """
99
+ Encodes the condition into tokens, ids and type_id.
100
+ """
101
+ if self.condition_type in [
102
+ "depth",
103
+ "canny",
104
+ "subject",
105
+ "coloring",
106
+ "deblurring",
107
+ "depth_pred",
108
+ "fill",
109
+ "sr",
110
+ "cartoon",
111
+ ]:
112
+ tokens, ids = encode_images(pipe, self.condition)
113
+ else:
114
+ raise NotImplementedError(
115
+ f"Condition type {self.condition_type} not implemented"
116
+ )
117
+ if self.position_delta is None and self.condition_type == "subject":
118
+ self.position_delta = [0, -self.condition.size[0] // 16]
119
+ if self.position_delta is not None:
120
+ ids[:, 1] += self.position_delta[0]
121
+ ids[:, 2] += self.position_delta[1]
122
+ if self.position_scale != 1.0:
123
+ scale_bias = (self.position_scale - 1.0) / 2
124
+ ids[:, 1] *= self.position_scale
125
+ ids[:, 2] *= self.position_scale
126
+ ids[:, 1] += scale_bias
127
+ ids[:, 2] += scale_bias
128
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
129
+ return tokens, ids, type_id
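A minimal usage sketch for the `Condition` class above; the image path is a placeholder and the FLUX checkpoint id is assumed to be the public FLUX.1-dev weights:

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Build a Canny-edge condition from a placeholder input image.
raw = Image.open("example_face.jpg").convert("RGB").resize((512, 512))
cond = Condition(condition_type="canny", raw_img=raw)

# Encode into packed latent tokens, positional ids and a type id.
with torch.no_grad():
    tokens, ids, type_id = cond.encode(pipe)
# tokens: (1, 1024, 64), ids: (1024, 3), type_id: (1024, 1) for a 512x512 input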
src/flux/generate.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+ import yaml, os
3
+ from diffusers.pipelines import FluxPipeline, StableDiffusionPipeline
4
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
5
+ rescale_noise_cfg,
6
+ )
7
+ from diffusers.utils import deprecate, is_torch_xla_available
8
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
9
+ StableDiffusionPipelineOutput,
10
+ )
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from torchvision import transforms as T
13
+ from typing import List, Union, Optional, Dict, Any, Callable
14
+ from .transformer import tranformer_forward
15
+ from .condition import Condition
16
+
17
+ from diffusers.pipelines.flux.pipeline_flux import (
18
+ FluxPipelineOutput,
19
+ calculate_shift,
20
+ retrieve_timesteps,
21
+ np,
22
+ )
23
+
24
+
25
+ def get_config(config_path: str = None):
26
+ config_path = config_path or os.environ.get("XFL_CONFIG")
27
+ if not config_path:
28
+ return {}
29
+ with open(config_path, "r") as f:
30
+ config = yaml.safe_load(f)
31
+ return config
32
+
33
+
34
+ def prepare_params(
35
+ prompt: Union[str, List[str]] = None,
36
+ prompt_2: Optional[Union[str, List[str]]] = None,
37
+ height: Optional[int] = 512,
38
+ width: Optional[int] = 512,
39
+ num_inference_steps: int = 28,
40
+ timesteps: List[int] = None,
41
+ guidance_scale: float = 3.5,
42
+ num_images_per_prompt: Optional[int] = 1,
43
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
44
+ latents: Optional[torch.FloatTensor] = None,
45
+ prompt_embeds: Optional[torch.FloatTensor] = None,
46
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
47
+ output_type: Optional[str] = "pil",
48
+ return_dict: bool = True,
49
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
50
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
51
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
52
+ max_sequence_length: int = 512,
53
+ **kwargs: dict,
54
+ ):
55
+ return (
56
+ prompt,
57
+ prompt_2,
58
+ height,
59
+ width,
60
+ num_inference_steps,
61
+ timesteps,
62
+ guidance_scale,
63
+ num_images_per_prompt,
64
+ generator,
65
+ latents,
66
+ prompt_embeds,
67
+ pooled_prompt_embeds,
68
+ output_type,
69
+ return_dict,
70
+ joint_attention_kwargs,
71
+ callback_on_step_end,
72
+ callback_on_step_end_tensor_inputs,
73
+ max_sequence_length,
74
+ )
75
+
76
+
77
+ def seed_everything(seed: int = 42):
78
+ torch.backends.cudnn.deterministic = True
79
+ torch.manual_seed(seed)
80
+ np.random.seed(seed)
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+ @torch.no_grad()
89
+ def generate(
90
+ pipeline: FluxPipeline,
91
+ mogle=None,
92
+ conditions: List[Condition] = None,
93
+ config_path: str = None,
94
+ model_config: Optional[Dict[str, Any]] = {},
95
+ condition_scale: float = 1.0,
96
+ default_lora: bool = False,
97
+ **params: dict,
98
+ ):
99
+ model_config = model_config or get_config(config_path).get("model", {})
100
+ if condition_scale != 1:
101
+ for name, module in pipeline.transformer.named_modules():
102
+ if not name.endswith(".attn"):
103
+ continue
104
+ module.c_factor = torch.ones(1, 1) * condition_scale
105
+
106
+ self = pipeline
107
+ (
108
+ prompt,
109
+ prompt_2,
110
+ height,
111
+ width,
112
+ num_inference_steps,
113
+ timesteps,
114
+ guidance_scale,
115
+ num_images_per_prompt,
116
+ generator,
117
+ latents,
118
+ prompt_embeds,
119
+ pooled_prompt_embeds,
120
+ output_type,
121
+ return_dict,
122
+ joint_attention_kwargs,
123
+ callback_on_step_end,
124
+ callback_on_step_end_tensor_inputs,
125
+ max_sequence_length,
126
+ ) = prepare_params(**params)
127
+
128
+ height = height or self.default_sample_size * self.vae_scale_factor
129
+ width = width or self.default_sample_size * self.vae_scale_factor
130
+
131
+ # 1. Check inputs. Raise error if not correct
132
+ self.check_inputs(
133
+ prompt,
134
+ prompt_2,
135
+ height,
136
+ width,
137
+ prompt_embeds=prompt_embeds,
138
+ pooled_prompt_embeds=pooled_prompt_embeds,
139
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
140
+ max_sequence_length=max_sequence_length,
141
+ )
142
+
143
+ self._guidance_scale = guidance_scale
144
+ self._joint_attention_kwargs = joint_attention_kwargs
145
+ self._interrupt = False
146
+
147
+ # 2. Define call parameters
148
+ if prompt is not None and isinstance(prompt, str):
149
+ batch_size = 1
150
+ elif prompt is not None and isinstance(prompt, list):
151
+ batch_size = len(prompt)
152
+ else:
153
+ batch_size = prompt_embeds.shape[0]
154
+
155
+ device = self._execution_device
156
+
157
+ lora_scale = (
158
+ self.joint_attention_kwargs.get("scale", None)
159
+ if self.joint_attention_kwargs is not None
160
+ else None
161
+ )
162
+ (
163
+ prompt_embeds,
164
+ pooled_prompt_embeds,
165
+ text_ids,
166
+ ) = self.encode_prompt(
167
+ prompt=prompt,
168
+ prompt_2=prompt_2,
169
+ prompt_embeds=prompt_embeds,
170
+ pooled_prompt_embeds=pooled_prompt_embeds,
171
+ device=device,
172
+ num_images_per_prompt=num_images_per_prompt,
173
+ max_sequence_length=max_sequence_length,
174
+ lora_scale=lora_scale,
175
+ )
176
+
177
+ # 4. Prepare latent variables
178
+ num_channels_latents = self.transformer.config.in_channels // 4
179
+ latents, latent_image_ids = self.prepare_latents(
180
+ batch_size * num_images_per_prompt,
181
+ num_channels_latents,
182
+ height,
183
+ width,
184
+ prompt_embeds.dtype,
185
+ device,
186
+ generator,
187
+ latents,
188
+ )
189
+
190
+ # 4.1. Prepare conditions
191
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
192
+ use_condition = conditions is not None and len(conditions) > 0
193
+ if use_condition:
194
+ assert len(conditions) <= 1, "Only one condition is supported for now."
195
+ if not default_lora:
196
+ pipeline.set_adapters(conditions[0].condition_type)
197
+ for condition in conditions:
198
+ tokens, ids, type_id = condition.encode(self)
199
+ #print(tokens.shape) # 20 1024 64
200
+ # bs, mask_num, channel, h, w = tokens.shape
201
+ tokens_reshape = tokens.reshape(1, -1, *tokens.shape[1:])
202
+ #print(tokens.shape) # 1 1024 64
203
+ condition_latents.append(tokens_reshape) # [batch_size, token_n, token_dim]
204
+ condition_ids.append(ids) # [token_n, id_dim(3)]
205
+ condition_type_ids.append(type_id) # [token_n, 1]
206
+ condition_latents = torch.cat(condition_latents, dim=1)
207
+ condition_ids = torch.cat(condition_ids, dim=0)
208
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
209
+
210
+ # 5. Prepare timesteps
211
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
212
+ image_seq_len = latents.shape[1]
213
+ mu = calculate_shift(
214
+ image_seq_len,
215
+ self.scheduler.config.base_image_seq_len,
216
+ self.scheduler.config.max_image_seq_len,
217
+ self.scheduler.config.base_shift,
218
+ self.scheduler.config.max_shift,
219
+ )
220
+ timesteps, num_inference_steps = retrieve_timesteps(
221
+ self.scheduler,
222
+ num_inference_steps,
223
+ device,
224
+ timesteps,
225
+ sigmas,
226
+ mu=mu,
227
+ )
228
+ num_warmup_steps = max(
229
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
230
+ )
231
+ self._num_timesteps = len(timesteps)
232
+
233
+ # 6. Denoising loop
234
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
235
+ for i, t in enumerate(timesteps):
236
+ if self.interrupt:
237
+ continue
238
+ # MoGLE re-weights the stacked condition latents for the current noisy latent and timestep
+ cur_condition_latents = mogle(condition_latents, latents, t.expand(1))
239
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
240
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
241
+
242
+ # handle guidance
243
+ if self.transformer.config.guidance_embeds:
244
+ guidance = torch.tensor([guidance_scale], device=device)
245
+ guidance = guidance.expand(latents.shape[0])
246
+ else:
247
+ guidance = None
248
+ noise_pred = tranformer_forward(
249
+ self.transformer,
250
+ model_config=model_config,
251
+ # Inputs of the condition (new feature)
252
+ condition_latents=cur_condition_latents if use_condition else None,
253
+ condition_ids=condition_ids if use_condition else None,
254
+ condition_type_ids=condition_type_ids if use_condition else None,
255
+ # Inputs to the original transformer
256
+ hidden_states=latents,
257
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
258
+ timestep=timestep / 1000,
259
+ guidance=guidance,
260
+ pooled_projections=pooled_prompt_embeds,
261
+ encoder_hidden_states=prompt_embeds,
262
+ txt_ids=text_ids,
263
+ img_ids=latent_image_ids,
264
+ joint_attention_kwargs=self.joint_attention_kwargs,
265
+ return_dict=False,
266
+ )[0]
267
+
268
+ # compute the previous noisy sample x_t -> x_t-1
269
+ latents_dtype = latents.dtype
270
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
271
+
272
+ if latents.dtype != latents_dtype:
273
+ if torch.backends.mps.is_available():
274
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
275
+ latents = latents.to(latents_dtype)
276
+
277
+ if callback_on_step_end is not None:
278
+ callback_kwargs = {}
279
+ for k in callback_on_step_end_tensor_inputs:
280
+ callback_kwargs[k] = locals()[k]
281
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
282
+
283
+ latents = callback_outputs.pop("latents", latents)
284
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
285
+
286
+ # call the callback, if provided
287
+ if i == len(timesteps) - 1 or (
288
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
289
+ ):
290
+ progress_bar.update()
291
+
292
+ if output_type == "latent":
293
+ image = latents
294
+
295
+ else:
296
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
297
+ latents = (
298
+ latents / self.vae.config.scaling_factor
299
+ ) + self.vae.config.shift_factor
300
+ image = self.vae.decode(latents, return_dict=False)[0]
301
+ image = self.image_processor.postprocess(image, output_type=output_type)
302
+
303
+ # Offload all models
304
+ self.maybe_free_model_hooks()
305
+
306
+ if condition_scale != 1:
307
+ for name, module in pipeline.transformer.named_modules():
308
+ if not name.endswith(".attn"):
309
+ continue
310
+ del module.c_factor
311
+
312
+ if not return_dict:
313
+ return (image,)
314
+
315
+ return FluxPipelineOutput(images=image)
316
+
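A rough end-to-end sketch of how `generate` is meant to be driven with the MoGLE weights shipped in this repository; the checkpoint id, file paths, seed and the `depth` type tag are illustrative assumptions:

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate
from src.moe.mogle import MoGLE

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("weights/pytorch_lora_weights.safetensors")

mogle = MoGLE().to("cuda", torch.bfloat16)
mogle.load_state_dict(torch.load("weights/mogle.pt", map_location="cuda"))
mogle.eval()

# Build the (20, 3, 512, 512) condition stack: global segmentation map + 19 binary class masks.
to_tensor = T.ToTensor()
mask_img = Image.open("data/mmcelebahq/mask/27000.png")
mask = np.array(mask_img)
mask_list = [to_tensor(mask_img.convert("RGB"))]
for seg_id in range(19):
    local = np.zeros_like(mask)
    local[mask == seg_id] = 255
    mask_list.append(to_tensor(Image.fromarray(local).convert("RGB")))

condition = Condition(
    condition_type="depth",  # type tag only; assumed to match the training config
    condition=torch.stack(mask_list, dim=0),
    position_delta=[0, 0],
)

result = generate(
    pipe,
    mogle=mogle,
    prompt="She is wearing lipstick. She is attractive and has straight hair.",
    conditions=[condition],
    height=512,
    width=512,
    generator=torch.Generator("cuda").manual_seed(42),
    default_lora=True,
)
result.images[0].save("sample.jpg")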
src/flux/lora_controller.py ADDED
@@ -0,0 +1,75 @@
1
+ from peft.tuners.tuners_utils import BaseTunerLayer
2
+ from typing import List, Any, Optional, Type
3
+
4
+
5
+ class enable_lora:
6
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7
+ self.activated: bool = activated
8
+ if activated:
9
+ return
10
+ self.lora_modules: List[BaseTunerLayer] = [
11
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
12
+ ]
13
+ self.scales = [
14
+ {
15
+ active_adapter: lora_module.scaling[active_adapter]
16
+ for active_adapter in lora_module.active_adapters
17
+ }
18
+ for lora_module in self.lora_modules
19
+ ]
20
+
21
+ def __enter__(self) -> None:
22
+ if self.activated:
23
+ return
24
+
25
+ for lora_module in self.lora_modules:
26
+ if not isinstance(lora_module, BaseTunerLayer):
27
+ continue
28
+ lora_module.scale_layer(0)
29
+
30
+ def __exit__(
31
+ self,
32
+ exc_type: Optional[Type[BaseException]],
33
+ exc_val: Optional[BaseException],
34
+ exc_tb: Optional[Any],
35
+ ) -> None:
36
+ if self.activated:
37
+ return
38
+ for i, lora_module in enumerate(self.lora_modules):
39
+ if not isinstance(lora_module, BaseTunerLayer):
40
+ continue
41
+ for active_adapter in lora_module.active_adapters:
42
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43
+
44
+
45
+ class set_lora_scale:
46
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47
+ self.lora_modules: List[BaseTunerLayer] = [
48
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
49
+ ]
50
+ self.scales = [
51
+ {
52
+ active_adapter: lora_module.scaling[active_adapter]
53
+ for active_adapter in lora_module.active_adapters
54
+ }
55
+ for lora_module in self.lora_modules
56
+ ]
57
+ self.scale = scale
58
+
59
+ def __enter__(self) -> None:
60
+ for lora_module in self.lora_modules:
61
+ if not isinstance(lora_module, BaseTunerLayer):
62
+ continue
63
+ lora_module.scale_layer(self.scale)
64
+
65
+ def __exit__(
66
+ self,
67
+ exc_type: Optional[Type[BaseException]],
68
+ exc_val: Optional[BaseException],
69
+ exc_tb: Optional[Any],
70
+ ) -> None:
71
+ for i, lora_module in enumerate(self.lora_modules):
72
+ if not isinstance(lora_module, BaseTunerLayer):
73
+ continue
74
+ for active_adapter in lora_module.active_adapters:
75
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
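Both helpers are context managers that temporarily rescale the active LoRA adapters and restore the original scales on exit. A toy sketch against a PEFT-wrapped linear (the module layout here is hypothetical):

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from src.flux.lora_controller import enable_lora, set_lora_scale

base = nn.Sequential(nn.Linear(64, 64))
model = get_peft_model(base, LoraConfig(r=4, lora_alpha=4, target_modules=["0"]))
lora_linear = model.base_model.model[0]  # becomes a BaseTunerLayer after wrapping

x = torch.randn(2, 64)
with enable_lora((lora_linear,), False):   # activated=False: LoRA scaled to 0 inside the block
    y_without_lora = model(x)
with set_lora_scale((lora_linear,), 2.0):  # LoRA contribution doubled inside the block
    y_boosted = model(x)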
src/flux/pipeline_tools.py ADDED
@@ -0,0 +1,52 @@
1
+ from diffusers.pipelines import FluxPipeline
2
+ from diffusers.utils import logging
3
+ from diffusers.pipelines.flux.pipeline_flux import logger
4
+ from torch import Tensor
5
+
6
+
7
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
8
+ images = pipeline.image_processor.preprocess(images)
9
+ images = images.to(pipeline.device).to(pipeline.dtype)
10
+ images = pipeline.vae.encode(images).latent_dist.sample()
11
+ images = (
12
+ images - pipeline.vae.config.shift_factor
13
+ ) * pipeline.vae.config.scaling_factor
14
+ images_tokens = pipeline._pack_latents(images, *images.shape)
15
+ images_ids = pipeline._prepare_latent_image_ids(
16
+ images.shape[0],
17
+ images.shape[2],
18
+ images.shape[3],
19
+ pipeline.device,
20
+ pipeline.dtype,
21
+ )
22
+ if images_tokens.shape[1] != images_ids.shape[0]:
23
+ images_ids = pipeline._prepare_latent_image_ids(
24
+ images.shape[0],
25
+ images.shape[2] // 2,
26
+ images.shape[3] // 2,
27
+ pipeline.device,
28
+ pipeline.dtype,
29
+ )
30
+ return images_tokens, images_ids
31
+
32
+
33
+ def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
34
+ # Turn off warnings (CLIP overflow)
35
+ logger.setLevel(logging.ERROR)
36
+ (
37
+ prompt_embeds,
38
+ pooled_prompt_embeds,
39
+ text_ids,
40
+ ) = pipeline.encode_prompt(
41
+ prompt=prompts,
42
+ prompt_2=None,
43
+ prompt_embeds=None,
44
+ pooled_prompt_embeds=None,
45
+ device=pipeline.device,
46
+ num_images_per_prompt=1,
47
+ max_sequence_length=max_sequence_length,
48
+ lora_scale=None,
49
+ )
50
+ # Turn on warnings
51
+ logger.setLevel(logging.WARNING)
52
+ return prompt_embeds, pooled_prompt_embeds, text_ids
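A small shape sketch for the two helpers above (assumes the public FLUX.1-dev weights and a CUDA device; the image path is a placeholder):

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.pipeline_tools import encode_images, prepare_text_input

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

img = Image.open("face.jpg").convert("RGB").resize((512, 512))
tokens, ids = encode_images(pipe, img)
# tokens: (1, 1024, 64) packed VAE latents; ids: (1024, 3) latent position ids

embeds, pooled, text_ids = prepare_text_input(pipe, ["a face photo"])
# embeds: (1, 512, 4096) T5 features; pooled: (1, 768) CLIP features; text_ids: (512, 3)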
src/flux/transformer.py ADDED
@@ -0,0 +1,257 @@
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from accelerate.utils import is_torch_version
7
+ from diffusers.models.transformers.transformer_flux import (
8
+ FluxTransformer2DModel,
9
+ Transformer2DModelOutput,
10
+ USE_PEFT_BACKEND,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ c_t=0,
54
+ **params: dict,
55
+ ):
56
+ self = transformer
57
+ use_condition = condition_latents is not None
58
+
59
+ (
60
+ hidden_states,
61
+ encoder_hidden_states,
62
+ pooled_projections,
63
+ timestep,
64
+ img_ids,
65
+ txt_ids,
66
+ guidance,
67
+ joint_attention_kwargs,
68
+ controlnet_block_samples,
69
+ controlnet_single_block_samples,
70
+ return_dict,
71
+ ) = prepare_params(**params)
72
+
73
+ if joint_attention_kwargs is not None:
74
+ joint_attention_kwargs = joint_attention_kwargs.copy()
75
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76
+ else:
77
+ lora_scale = 1.0
78
+
79
+ if USE_PEFT_BACKEND:
80
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
81
+ scale_lora_layers(self, lora_scale)
82
+ else:
83
+ if (
84
+ joint_attention_kwargs is not None
85
+ and joint_attention_kwargs.get("scale", None) is not None
86
+ ):
87
+ logger.warning(
88
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89
+ )
90
+
91
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92
+ hidden_states = self.x_embedder(hidden_states)
93
+ # print("hidden states :",hidden_states.shape) hidden states : torch.Size([2, 1024, 3072])
94
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
95
+ # print(f"condition_latents shape {condition_latents.shape}") condition_latents shape torch.Size([2, 1024, 3072])
96
+
97
+ timestep = timestep.to(hidden_states.dtype) * 1000
98
+
99
+ if guidance is not None:
100
+ guidance = guidance.to(hidden_states.dtype) * 1000
101
+ else:
102
+ guidance = None
103
+
104
+ temb = (
105
+ self.time_text_embed(timestep, pooled_projections)
106
+ if guidance is None
107
+ else self.time_text_embed(timestep, guidance, pooled_projections)
108
+ )
109
+ # print(f"temb shape:{temb.shape}") torch.Size([2, 3072])
110
+
111
+ cond_temb = (
112
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
113
+ if guidance is None
114
+ else self.time_text_embed(
115
+ torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
116
+ )
117
+ )
118
+ # print("cond temb shape",cond_temb.shape) torch.Size([2, 3072])
119
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
120
+ # print(f"encoder hidden states {encoder_hidden_states.shape}") torch.Size([2, 512, 3072])
121
+
122
+ if txt_ids.ndim == 3:
123
+ logger.warning(
124
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
125
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
126
+ )
127
+ txt_ids = txt_ids[0]
128
+ if img_ids.ndim == 3:
129
+ logger.warning(
130
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
131
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
132
+ )
133
+ img_ids = img_ids[0]
134
+
135
+ ids = torch.cat((txt_ids, img_ids), dim=0) # 1536 3
136
+ image_rotary_emb = self.pos_embed(ids) # 2 1536 128
137
+
138
+ if use_condition:
139
+ # condition_ids[:, :1] = condition_type_ids
140
+ cond_rotary_emb = self.pos_embed(condition_ids) # 2 1536 128
141
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
142
+
143
+ for index_block, block in enumerate(self.transformer_blocks):
144
+ if self.training and self.gradient_checkpointing:
145
+ ckpt_kwargs: Dict[str, Any] = (
146
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
147
+ )
148
+ encoder_hidden_states, hidden_states, condition_latents = (
149
+ torch.utils.checkpoint.checkpoint(
150
+ block_forward,
151
+ self=block,
152
+ model_config=model_config,
153
+ hidden_states=hidden_states,
154
+ encoder_hidden_states=encoder_hidden_states,
155
+ condition_latents=condition_latents if use_condition else None,
156
+ temb=temb,
157
+ cond_temb=cond_temb if use_condition else None,
158
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
159
+ image_rotary_emb=image_rotary_emb,
160
+ **ckpt_kwargs,
161
+ )
162
+ )
163
+
164
+ else:
165
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
166
+ block,
167
+ model_config=model_config,
168
+ hidden_states=hidden_states,
169
+ encoder_hidden_states=encoder_hidden_states,
170
+ condition_latents=condition_latents if use_condition else None,
171
+ temb=temb,
172
+ cond_temb=cond_temb if use_condition else None,
173
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
174
+ image_rotary_emb=image_rotary_emb,
175
+ )
176
+
177
+ # controlnet residual
178
+ if controlnet_block_samples is not None:
179
+ interval_control = len(self.transformer_blocks) / len(
180
+ controlnet_block_samples
181
+ )
182
+ interval_control = int(np.ceil(interval_control))
183
+ hidden_states = (
184
+ hidden_states
185
+ + controlnet_block_samples[index_block // interval_control]
186
+ )
187
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
188
+
189
+ for index_block, block in enumerate(self.single_transformer_blocks):
190
+ if self.training and self.gradient_checkpointing:
191
+ ckpt_kwargs: Dict[str, Any] = (
192
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
193
+ )
194
+ result = torch.utils.checkpoint.checkpoint(
195
+ single_block_forward,
196
+ self=block,
197
+ model_config=model_config,
198
+ hidden_states=hidden_states,
199
+ temb=temb,
200
+ image_rotary_emb=image_rotary_emb,
201
+ **(
202
+ {
203
+ "condition_latents": condition_latents,
204
+ "cond_temb": cond_temb,
205
+ "cond_rotary_emb": cond_rotary_emb,
206
+ }
207
+ if use_condition
208
+ else {}
209
+ ),
210
+ **ckpt_kwargs,
211
+ )
212
+
213
+ else:
214
+ result = single_block_forward(
215
+ block,
216
+ model_config=model_config,
217
+ hidden_states=hidden_states,
218
+ temb=temb,
219
+ image_rotary_emb=image_rotary_emb,
220
+ **(
221
+ {
222
+ "condition_latents": condition_latents,
223
+ "cond_temb": cond_temb,
224
+ "cond_rotary_emb": cond_rotary_emb,
225
+ }
226
+ if use_condition
227
+ else {}
228
+ ),
229
+ )
230
+ if use_condition:
231
+ hidden_states, condition_latents = result
232
+ else:
233
+ hidden_states = result
234
+
235
+ # controlnet residual
236
+ if controlnet_single_block_samples is not None:
237
+ interval_control = len(self.single_transformer_blocks) / len(
238
+ controlnet_single_block_samples
239
+ )
240
+ interval_control = int(np.ceil(interval_control))
241
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
242
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
243
+ + controlnet_single_block_samples[index_block // interval_control]
244
+ )
245
+
246
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
247
+
248
+ hidden_states = self.norm_out(hidden_states, temb)
249
+ output = self.proj_out(hidden_states)
250
+ # print(f"output shape:{output.shape}")
251
+ if USE_PEFT_BACKEND:
252
+ # remove `lora_scale` from each PEFT layer
253
+ unscale_lora_layers(self, lora_scale)
254
+
255
+ if not return_dict:
256
+ return (output,)
257
+ return Transformer2DModelOutput(sample=output)
src/moe/__pycache__/mogle.cpython-311.pyc ADDED
Binary file (7.47 kB).
src/moe/mogle.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
4
+ import torch.optim as optim
5
+ from torch.nn import functional as F
6
+
7
+
8
+ # Define the Expert Network
9
+ class Expert(nn.Module):
10
+ def __init__(self, input_dim, hidden_dim, output_dim, use_softmax=False):
11
+ super(Expert, self).__init__()
12
+
13
+ self.use_softmax = use_softmax
14
+
15
+ self.net = nn.Sequential(
16
+ nn.Linear(input_dim, hidden_dim),
17
+ nn.ReLU(),
18
+ nn.Linear(hidden_dim, output_dim),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return (
23
+ self.net(x) if not self.use_softmax else torch.softmax(self.net(x), dim=1)
24
+ )
25
+
26
+
27
+ class DynamicGatingNetwork(nn.Module):
28
+ def __init__(self, hidden_dim=64, embed_dim=64, dtype=torch.bfloat16):
29
+ super().__init__()
30
+
31
+ # Embed the timestep
32
+ self.time_proj = Timesteps(
33
+ hidden_dim, flip_sin_to_cos=True, downscale_freq_shift=0
34
+ )
35
+ self.timestep_embedding = TimestepEmbedding(hidden_dim, embed_dim)
36
+ self.timestep_embedding = self.timestep_embedding.to(dtype=torch.bfloat16)
37
+ # Project the noise latent
38
+ self.noise_proj = nn.Linear(hidden_dim, hidden_dim)
39
+ self.dtype = dtype
40
+
41
+ # Gating weight computation
42
+ self.gate = nn.Sequential(
43
+ nn.Linear(hidden_dim, hidden_dim),
44
+ nn.ReLU(),
45
+ nn.Linear(hidden_dim, 20),  # one weight per expert (20 in total)
46
+ )
47
+
48
+ def forward(self, condition_latents, noise_latent, timestep):
49
+ """
50
+ condition_latents: (bs, 1024, 64)
51
+ noise_latent: (bs, 1024, 64)
52
+ timestep: (bs,)
53
+ """
54
+ bs, seq_len, hidden_dim = condition_latents.shape
55
+
56
+ # Embed the timestep
57
+ time_emb = self.time_proj(timestep) # (bs, hidden_dim)
58
+ time_emb = time_emb.to(self.dtype)
59
+ time_emb = self.timestep_embedding(time_emb) # (bs, embed_dim)
60
+
61
+ time_emb = time_emb.unsqueeze(1).expand(
62
+ -1, seq_len, -1
63
+ ) # (bs, 1024, embed_dim)
64
+
65
+ # Project the noise latent
66
+ noise_emb = self.noise_proj(noise_latent) # (bs, 1024, 64)
67
+ # Fuse all inputs
68
+ # fused_input = torch.cat([condition_latents, noise_emb, time_emb], dim=2) # (bs, 1024, 64+64+128)
69
+ fused_input = condition_latents + noise_emb + time_emb
70
+ # Compute the gating weights
71
+ weight = self.gate(fused_input)  # (bs, 1024, 20)
72
+ weight = F.softmax(weight, dim=2)  # normalize over the experts
73
+
74
+ return weight
75
+
76
+ class MoGLE(nn.Module):
77
+ def __init__(
78
+ self,
79
+ num_experts=20,
80
+ input_dim=64,
81
+ hidden_dim=32,
82
+ output_dim=64,
83
+ has_expert=True,
84
+ has_gating=True,
85
+ weight_is_scale=False,
86
+ ):
87
+ super().__init__()
88
+ expert_model = None
89
+ if has_expert:
90
+ expert_model = Expert
91
+ else:
92
+ expert_model = nn.Identity
93
+ self.global_expert = expert_model(
94
+ input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
95
+ )
96
+ self.local_experts = nn.ModuleList(
97
+ [
98
+ expert_model(
99
+ input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
100
+ )
101
+ for _ in range(num_experts - 1)
102
+ ]
103
+ )
104
+ # self.gating = Gating(input_dim=input_dim, num_experts=num_experts)
105
+ if has_gating:
106
+ self.gating = DynamicGatingNetwork()
107
+ else:
108
+ self.gating = nn.Identity()
109
+
110
+ self.weight_is_scale = weight_is_scale
111
+
112
+ def forward(self, x: torch.Tensor, noise_latent, timestep):
113
+ global_mask = x[:, 0] # bs 1024 64
114
+ local_mask = x[:, 1:] # bs 19 1024 64
115
+ if not isinstance(self.gating, nn.Identity):
116
+ weights = self.gating.forward(
117
+ global_mask, noise_latent=noise_latent, timestep=timestep
118
+ ) # bs 1024 20
119
+
120
+ _, num_local, h, w = local_mask.shape
121
+ global_output = self.global_expert(global_mask).unsqueeze(1)
122
+ local_outputs = torch.stack(
123
+ [self.local_experts[i](local_mask[:, i]) for i in range(num_local)], dim=1
124
+ ) # (bs, 19, 1024, 64)
125
+ global_local_outputs = torch.cat(
126
+ [global_output, local_outputs], dim=1
127
+ ) # bs 20 1024 64
128
+
129
+ if isinstance(self.gating, nn.Identity):
130
+ global_local_outputs = global_local_outputs.sum(dim=1)
131
+ return global_local_outputs
132
+ if self.weight_is_scale:
133
+ weights = torch.mean(weights, dim=1, keepdim=True) # bs 1 20
134
+ # print("gating scale")
135
+
136
+ weights_expanded = weights.unsqueeze(-1)
137
+ output = (global_local_outputs.permute(0, 2, 1, 3) * weights_expanded).sum(
138
+ dim=2
139
+ )
140
+ return output # bs 1024 64
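A self-contained shape check for `MoGLE` with random tensors on CPU and the default hyper-parameters used elsewhere in this repository:

import torch
from src.moe.mogle import MoGLE

mogle = MoGLE(num_experts=20, input_dim=64, hidden_dim=32, output_dim=64)

x = torch.randn(2, 20, 1024, 64)          # global mask latent + 19 per-class latents
noise_latent = torch.randn(2, 1024, 64)   # packed noisy image latent x_t
timestep = torch.rand(2)                  # flow-matching t in [0, 1]

out = mogle(x, noise_latent=noise_latent, timestep=timestep)
print(out.shape)                          # torch.Size([2, 1024, 64])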
src/train/callbacks.py ADDED
@@ -0,0 +1,170 @@
1
+ import lightning as L
2
+ from PIL import Image, ImageFilter, ImageDraw
3
+ import numpy as np
4
+ from transformers import pipeline
5
+ import cv2
6
+ import torch
7
+ import os
8
+ from torchvision import transforms as T
9
+ try:
10
+ import wandb
11
+ except ImportError:
12
+ wandb = None
13
+
14
+ from ..flux.condition import Condition
15
+ from ..flux.generate import generate
16
+
17
+
18
+ class FaceMoGLECallback(L.Callback):
19
+
20
+ def __init__(self, run_name, training_config: dict = {}):
21
+ self.run_name, self.training_config = run_name, training_config
22
+
23
+ self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
24
+ self.save_interval = training_config.get("save_interval", 1000)
25
+ self.sample_interval = training_config.get("sample_interval", 1000)
26
+ self.save_path = training_config.get("save_path", "./runs")
27
+
28
+ self.wandb_config = training_config.get("wandb", None)
29
+ self.use_wandb = (
30
+ wandb is not None and os.environ.get("WANDB_API_KEY") is not None
31
+ )
32
+
33
+ self.total_steps = 0
34
+
35
+ def to_tensor(self, x):
36
+ return T.ToTensor()(x)
37
+
38
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
39
+ gradient_size = 0
40
+ max_gradient_size = 0
41
+ count = 0
42
+ for _, param in pl_module.named_parameters():
43
+ if param.grad is not None:
44
+ gradient_size += param.grad.norm(2).item()
45
+ max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
46
+ count += 1
47
+ if count > 0:
48
+ gradient_size /= count
49
+
50
+ self.total_steps += 1
51
+
52
+ # Print training progress every n steps
53
+ if self.use_wandb:
54
+ report_dict = {
55
+ "steps": batch_idx,
56
+ "steps": self.total_steps,
57
+ "epoch": trainer.current_epoch,
58
+ "gradient_size": gradient_size,
59
+ }
60
+ loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
61
+ report_dict["loss"] = loss_value
62
+ report_dict["t"] = pl_module.last_t
63
+ wandb.log(report_dict)
64
+
65
+ if self.total_steps % self.print_every_n_steps == 0:
66
+ print(
67
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
68
+ )
69
+
70
+ # Save LoRA weights at specified intervals
71
+ if self.total_steps % self.save_interval == 0:
72
+ print(
73
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
74
+ )
75
+ pl_module.save_lora(
76
+ f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
77
+ )
78
+ if hasattr(pl_module, "save_moe"):
79
+ pl_module.save_moe(
80
+ f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}/moe.pt"
81
+ )
82
+
83
+ # Generate and save a sample image at specified intervals
84
+ if self.total_steps % self.sample_interval == 0:
85
+ print(
86
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
87
+ )
88
+ self.generate_a_sample(
89
+ trainer,
90
+ pl_module,
91
+ f"{self.save_path}/{self.run_name}/output",
92
+ f"lora_{self.total_steps}",
93
+ batch["condition_type"][
94
+ 0
95
+ ], # Use the condition type from the current batch
96
+ )
97
+
98
+
99
+ @torch.no_grad()
100
+ def generate_a_sample(
101
+ self,
102
+ trainer,
103
+ pl_module,
104
+ save_path,
105
+ file_name,
106
+ condition_type="super_resolution",
107
+ ):
108
+ # TODO: change this two variables to parameters
109
+ target_size = trainer.training_config["dataset"]["target_size"]
110
+ position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)
111
+
112
+ generator = torch.Generator(device=pl_module.device)
113
+ generator.manual_seed(42)
114
+
115
+ test_list = []
116
+
117
+ condition_img_path = "data/mmcelebahq/mask/27000.png"
118
+
119
+ # condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB")
120
+ test_list.append(
121
+ (
122
+ condition_img_path,
123
+ [0, 0],
124
+ "She is wearing lipstick. She is attractive and has straight hair.",
125
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
126
+ )
127
+ )
128
+
129
+
130
+ if not os.path.exists(save_path):
131
+ os.makedirs(save_path)
132
+ for i, (condition_img_path, position_delta, prompt, *others) in enumerate(
133
+ test_list
134
+ ):
135
+
136
+ global_mask = Image.open(condition_img_path).convert("RGB")
137
+ mask_list = [self.to_tensor(global_mask)]
138
+ mask = Image.open(condition_img_path)
139
+ mask = np.array(mask)
140
+ for i in range(19):
141
+ local_mask = np.zeros_like(mask)
142
+ local_mask[mask == i] = 255
143
+
144
+ local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
145
+ local_mask_tensor = self.to_tensor(local_mask_rgb)
146
+ mask_list.append(local_mask_tensor)
147
+ condition_img = torch.stack(mask_list, dim=0)
148
+ # condition_img = condition_img.unsqueeze(0)
149
+
150
+ condition = Condition(
151
+ condition_type=condition_type,
152
+ condition=condition_img,
153
+ position_delta=position_delta,
154
+ **(others[0] if others else {}),
155
+ )
156
+
157
+ res = generate(
158
+ pl_module.flux_pipe,
159
+ mogle=pl_module.mogle,
160
+ prompt=prompt,
161
+ conditions=[condition],
162
+ height=target_size,
163
+ width=target_size,
164
+ generator=generator,
165
+ model_config=pl_module.model_config,
166
+ default_lora=True,
167
+ )
168
+ res.images[0].save(
169
+ os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
170
+ )
src/train/data.py ADDED
@@ -0,0 +1,98 @@
1
+ from PIL import Image
2
+ import os
3
+ import numpy as np
4
+ from torch.utils.data import Dataset
5
+ import torchvision.transforms as T
6
+ import random
7
+ import torch
8
+ import json
9
+
10
+
11
+ class MMCelebAHQ(Dataset):
12
+ def __init__(
13
+ self,
14
+ root="data/mmcelebahq",
15
+ condition_size: int = 512,
16
+ target_size: int = 512,
17
+ condition_type: str = "depth",
18
+ drop_text_prob: float = 0.1,
19
+ drop_image_prob: float = 0.1,
20
+ return_pil_image: bool = False,
21
+ position_scale=1.0,
22
+ ):
23
+ self.root = root
24
+ self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
25
+ self.condition_size = condition_size
26
+ self.target_size = target_size
27
+ self.condition_type = condition_type
28
+ self.drop_text_prob = drop_text_prob
29
+ self.drop_image_prob = drop_image_prob
30
+ self.return_pil_image = return_pil_image
31
+ self.position_scale = position_scale
32
+
33
+ self.to_tensor = T.ToTensor()
34
+
35
+ def get_face_mask_prompt(self):
36
+ face_paths = [
37
+ os.path.join(self.root, "face", f"{i}.jpg") for i in range(0, 27000)
38
+ ]
39
+ mask_paths = [
40
+ os.path.join(self.root, "mask", f"{i}.png") for i in range(0, 27000)
41
+ ]
42
+ with open(os.path.join(self.root, "text.json"), mode="r") as f:
43
+ prompts = json.load(f)
44
+ return face_paths, mask_paths, prompts
45
+
46
+ def __len__(self):
47
+ return len(self.face_paths)
48
+
49
+ def __getitem__(self, idx):
50
+ image = Image.open(self.face_paths[idx]).convert("RGB")
51
+ prompts = self.prompts[f"{idx}.jpg"]
52
+ description = random.choices(prompts, k=1)[0].strip()
53
+ enable_scale = random.random() < 1  # always True, so the rescaling branch below is effectively disabled
54
+ if not enable_scale:
55
+ condition_size = int(self.condition_size * self.position_scale)
56
+ position_scale = 1.0
57
+ else:
58
+ condition_size = self.condition_size
59
+ position_scale = self.position_scale
60
+
61
+ # Get the condition image
62
+ position_delta = np.array([0, 0])
63
+
64
+ mask = np.array(Image.open(self.mask_paths[idx]))
65
+ mask_list = [self.to_tensor(Image.open(self.mask_paths[idx]).convert("RGB"))]
66
+ for i in range(19):
67
+ local_mask = np.zeros_like(mask)
68
+ local_mask[mask == i] = 255
69
+
70
+ drop_image = random.random() < self.drop_image_prob
71
+ if drop_image:
72
+ local_mask = np.zeros_like(mask)
73
+
74
+ local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
75
+ local_mask_tensor = self.to_tensor(local_mask_rgb)
76
+ mask_list.append(local_mask_tensor)
77
+ condition_img = torch.stack(mask_list, dim=0)
78
+
79
+
80
+ # Randomly drop text or image
81
+ drop_text = random.random() < self.drop_text_prob
82
+ # drop_image = random.random() < self.drop_image_prob
83
+ if drop_text:
84
+ description = ""
85
+
86
+ return {
87
+ "image": self.to_tensor(image),
88
+ "condition": condition_img,
89
+ # "condition": self.to_tensor(condition_img),
90
+ "condition_type": self.condition_type,
91
+ "description": description,
92
+ "position_delta": position_delta,
93
+ **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
94
+ **({"position_scale": position_scale} if position_scale != 1.0 else {}),
95
+ }
96
+
97
+
98
+
src/train/model.py ADDED
@@ -0,0 +1,201 @@
1
+ import lightning as L
2
+ from diffusers.pipelines import FluxPipeline
3
+ import torch
4
+ from peft import LoraConfig, get_peft_model_state_dict
5
+
6
+ import prodigyopt
7
+ import os
8
+ from ..flux.transformer import tranformer_forward
9
+ from ..flux.condition import Condition
10
+ from ..flux.pipeline_tools import encode_images, prepare_text_input
11
+
12
+ from ..moe.mogle import MoGLE
13
+
14
+
15
+ class FaceMoGLE(L.LightningModule):
16
+ def __init__(
17
+ self,
18
+ flux_pipe_id: str,
19
+ lora_path: str = None,
20
+ lora_config: dict = None,
21
+ device: str = "cuda",
22
+ dtype: torch.dtype = torch.bfloat16,
23
+ model_config: dict = {},
24
+ optimizer_config: dict = None,
25
+ gradient_checkpointing: bool = False,
26
+ has_expert=True,
27
+ has_gating=True,
28
+ weight_is_scale=False
29
+ ):
30
+ # Initialize the LightningModule
31
+ super().__init__()
32
+ self.model_config = model_config
33
+ self.optimizer_config = optimizer_config
34
+
35
+ # Load the Flux pipeline
36
+ self.flux_pipe: FluxPipeline = (
37
+ FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
38
+ )
39
+ self.transformer = self.flux_pipe.transformer
40
+ self.transformer.gradient_checkpointing = gradient_checkpointing
41
+ self.transformer.train()
42
+ self.mogle = MoGLE(has_expert=has_expert,has_gating=has_gating,weight_is_scale=weight_is_scale)
43
+ self.mogle.train()
44
+ # Freeze the Flux pipeline
45
+ self.flux_pipe.text_encoder.requires_grad_(False).eval()
46
+ self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
47
+ self.flux_pipe.vae.requires_grad_(False).eval()
48
+
49
+ # Initialize LoRA layers
50
+ self.lora_layers = self.init_lora(lora_path, lora_config)
51
+
52
+ self.to(device).to(dtype)
53
+
54
+ def init_lora(self, lora_path: str, lora_config: dict):
55
+ assert lora_path or lora_config
56
+ if lora_path:
57
+ # TODO: Implement this
58
+ raise NotImplementedError
59
+ else:
60
+ self.transformer.add_adapter(LoraConfig(**lora_config))
61
+ # TODO: Check if this is correct (p.requires_grad)
62
+ lora_layers = filter(
63
+ lambda p: p.requires_grad, self.transformer.parameters()
64
+ )
65
+ return list(lora_layers)
66
+
67
+ def save_lora(self, path: str):
68
+ FluxPipeline.save_lora_weights(
69
+ save_directory=path,
70
+ transformer_lora_layers=get_peft_model_state_dict(self.transformer),
71
+ safe_serialization=True,
72
+ )
73
+ torch.save(self.mogle.state_dict(), os.path.join(path, "mogle.pt"))
74
+
75
+
76
+ def configure_optimizers(self):
77
+ # Freeze the transformer
78
+ self.transformer.requires_grad_(False)
79
+ opt_config = self.optimizer_config
80
+
81
+ # Set the trainable parameters
82
+ self.trainable_params = self.lora_layers + [p for p in self.mogle.parameters()]
83
+
84
+ # Unfreeze trainable parameters
85
+ for p in self.trainable_params:
86
+ p.requires_grad_(True)
87
+
88
+ # Initialize the optimizer
89
+ if opt_config["type"] == "AdamW":
90
+ optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
91
+ elif opt_config["type"] == "Prodigy":
92
+ optimizer = prodigyopt.Prodigy(
93
+ self.trainable_params,
94
+ **opt_config["params"],
95
+ )
96
+ elif opt_config["type"] == "SGD":
97
+ optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
98
+ else:
99
+ raise NotImplementedError
100
+
101
+ return optimizer
102
+
103
+ def training_step(self, batch, batch_idx):
104
+ step_loss = self.step(batch)
105
+ self.log_loss = (
106
+ step_loss.item()
107
+ if not hasattr(self, "log_loss")
108
+ else self.log_loss * 0.95 + step_loss.item() * 0.05
109
+ )
110
+ return step_loss
111
+
112
+ def step(self, batch):
113
+ imgs = batch["image"]
114
+ conditions = batch["condition"] # bsx20x3x512x512
115
+ condition_types = batch["condition_type"]
116
+ prompts = batch["description"]
117
+ position_delta = batch["position_delta"][0]
118
+ position_scale = float(batch.get("position_scale", [1.0])[0])
119
+
120
+ # Prepare inputs
121
+ with torch.no_grad():
122
+ # Prepare image input
123
+ x_0, img_ids = encode_images(self.flux_pipe, imgs)
124
+
125
+ # Prepare text input
126
+ prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
127
+ self.flux_pipe, prompts
128
+ )
129
+
130
+ # Prepare t and x_t
131
+ t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
132
+ x_1 = torch.randn_like(x_0).to(self.device)
133
+ t_ = t.unsqueeze(1).unsqueeze(1)
134
+ x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
135
+
136
+ # Prepare conditions # condition_latents \in bsx64x32x32 -> bsx(32x32)x64, condition_ids \in [1024, 3]
137
+ # initial conditions shape [bs, 20, 3, 512, 512] (global mask + 19 per-class masks), reshaped to [bs*20, 3, 512, 512]
138
+ c_bs, c_classes, c_channels, c_h, c_w = conditions.shape
139
+ conditions = conditions.view(c_bs * c_classes, c_channels, c_h, c_w)
140
+
141
+ condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
142
+ condition_latents_reshape = condition_latents.reshape(c_bs, c_classes, *condition_latents.shape[-2:]) # bs 20 1024 64
143
+ condition_latents = self.mogle.forward(condition_latents_reshape, noise_latent=x_t, timestep=t)
144
+ # condition_latents: [bs*20, 1024, 64] from the VAE, fused by MoGLE into [bs, 1024, 64] (condition features)
145
+ # condition_ids shape [1024, 3] # this is position embedding
146
+ # the fusion of the 20 per-mask latents is handled by the MoGLE module above
147
+
148
+
149
+ # Add position delta
150
+ condition_ids[:, 1] += position_delta[0]
151
+ condition_ids[:, 2] += position_delta[1]
152
+
153
+ if position_scale != 1.0:
154
+ scale_bias = (position_scale - 1.0) / 2
155
+ condition_ids[:, 1] *= position_scale
156
+ condition_ids[:, 2] *= position_scale
157
+ condition_ids[:, 1] += scale_bias
158
+ condition_ids[:, 2] += scale_bias
159
+
160
+ # Prepare condition type
161
+ condition_type_ids = torch.tensor(
162
+ [
163
+ Condition.get_type_id(condition_type)
164
+ for condition_type in condition_types
165
+ ]
166
+ ).to(self.device)
167
+ condition_type_ids = (
168
+ torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
169
+ ).unsqueeze(1)
170
+
171
+ # Prepare guidance
172
+ guidance = (
173
+ torch.ones_like(t).to(self.device)
174
+ if self.transformer.config.guidance_embeds
175
+ else None
176
+ )
177
+ # Forward pass
178
+ transformer_out = tranformer_forward(
179
+ self.transformer,
180
+ # Model config
181
+ model_config=self.model_config,
182
+ # Inputs of the condition (new feature)
183
+ condition_latents=condition_latents,
184
+ condition_ids=condition_ids,
185
+ condition_type_ids=condition_type_ids,
186
+ # Inputs to the original transformer
187
+ hidden_states=x_t,
188
+ timestep=t,
189
+ guidance=guidance,
190
+ pooled_projections=pooled_prompt_embeds,
191
+ encoder_hidden_states=prompt_embeds,
192
+ txt_ids=text_ids,
193
+ img_ids=img_ids,
194
+ joint_attention_kwargs=None,
195
+ return_dict=False,
196
+ )
197
+ pred = transformer_out[0]
198
+ # Compute loss
199
+ loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
200
+ self.last_t = t.mean().item()
201
+ return loss
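A hedged construction sketch for the Lightning module above; the LoRA and optimizer settings here are illustrative placeholders, since the real values come from the YAML config referenced by `XFL_CONFIG`:

import lightning as L
from torch.utils.data import DataLoader
from src.train.data import MMCelebAHQ
from src.train.model import FaceMoGLE

model = FaceMoGLE(
    flux_pipe_id="black-forest-labs/FLUX.1-dev",
    lora_config={
        "r": 4,
        "lora_alpha": 4,
        "init_lora_weights": "gaussian",
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    },
    optimizer_config={"type": "Prodigy", "params": {"lr": 1.0}},
    gradient_checkpointing=True,
)

loader = DataLoader(MMCelebAHQ(root="data/mmcelebahq"), batch_size=1, shuffle=True)
trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1)
trainer.fit(model, loader)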
weights/mogle.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b071a349d1e8f922d32a066014f9cc80b39f8db55043d8bdf04e79e156d4f243
3
+ size 238252
weights/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b2202f249a33252ce4f630db2f9536a28caf4b90e27927633f1f3bbb121f774
3
+ size 29066872