marks committed
Commit · 3778bc0 · Parent(s): 002c192
Init
Files changed:
- .gitignore +17 -0
- Dockerfile +30 -0
- LICENSE +201 -0
- api.py +122 -0
- app.py +167 -4
- configs/config-dev-1-RTX6000ADA-Copy1.json +57 -0
- configs/config-dev-1-RTX6000ADA.json +57 -0
- configs/config-dev-cuda0.json +56 -0
- configs/config-dev-eval.json +57 -0
- configs/config-dev-gigaquant.json +58 -0
- configs/config-dev-offload-1-4080.json +58 -0
- configs/config-dev-offload-1-4090.json +58 -0
- configs/config-dev-offload.json +58 -0
- configs/config-dev-prequant.json +57 -0
- configs/config-dev.json +57 -0
- configs/config-f8.json +48 -0
- configs/config-schnell-cuda0.json +57 -0
- configs/config-schnell.json +57 -0
- dark.safetensors +3 -0
- f8.json +48 -0
- float8_quantize.py +496 -0
- flux_emphasis.py +447 -0
- flux_pipeline.py +729 -0
- image_encoder.py +35 -0
- lora_loading.py +753 -0
- main.py +199 -0
- modules/autoencoder.py +336 -0
- modules/conditioner.py +128 -0
- modules/flux_model.py +734 -0
- photo.safetensors +3 -0
- start.py +0 -0
- util.py +333 -0
.gitignore
ADDED
@@ -0,0 +1,17 @@
__pycache__
*.jpg
*.png
*.jpeg
*.gif
*.bmp
*.webp
*.mp4
*.mp3
*.mp3
*.txt
.copilotignore
.misc
BFL-flux-diffusers
.env
.env.*
perfection.safetensors
Dockerfile
ADDED
@@ -0,0 +1,30 @@
# Base image with Python 3.11, PyTorch, CUDA 12.4.1, and Ubuntu 22.04
FROM runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04

# Set the working directory inside the container
WORKDIR /workspace

# Install necessary packages and dependencies
RUN apt-get update && apt-get install -y \
    wget \
    git \
    && rm -rf /var/lib/apt/lists/*

# Clone your repository
RUN git clone https://github.com/Yuanshi9815/OminiControl

# Change directory to the cloned repo
WORKDIR /workspace/fp8

# Install Python dependencies
RUN pip install -r requirements.txt

# Download the required model files
RUN wget https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors -O /workspace/flux1-schnell.safetensors && \
    wget https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors -O /workspace/ae.safetensors

# Expose necessary HTTP ports
EXPOSE 8888 7860

# Set the command to run your Python script
CMD ["python", "main_gr.py"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright 2024 Alex Redden

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
api.py
ADDED
@@ -0,0 +1,122 @@
from typing import Literal, Optional, TYPE_CHECKING

import numpy as np
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from platform import system

if TYPE_CHECKING:
    from flux_pipeline import FluxPipeline

if system() == "Windows":
    MAX_RAND = 2**16 - 1
else:
    MAX_RAND = 2**32 - 1


class AppState:
    model: "FluxPipeline"


class FastAPIApp(FastAPI):
    state: AppState


class LoraArgs(BaseModel):
    scale: Optional[float] = 1.0
    path: Optional[str] = None
    name: Optional[str] = None
    action: Optional[Literal["load", "unload"]] = "load"


class LoraLoadResponse(BaseModel):
    status: Literal["success", "error"]
    message: Optional[str] = None


class GenerateArgs(BaseModel):
    prompt: str
    width: Optional[int] = Field(default=720)
    height: Optional[int] = Field(default=1024)
    num_steps: Optional[int] = Field(default=24)
    guidance: Optional[float] = Field(default=3.5)
    seed: Optional[int] = Field(
        default_factory=lambda: np.random.randint(0, MAX_RAND), gt=0, lt=MAX_RAND
    )
    strength: Optional[float] = 1.0
    init_image: Optional[str] = None


app = FastAPIApp()


@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Generates an image from the Flux flow transformer.

    Args:
        args (GenerateArgs): Arguments for image generation:

        - `prompt`: The prompt used for image generation.
        - `width`: The width of the image.
        - `height`: The height of the image.
        - `num_steps`: The number of steps for the image generation.
        - `guidance`: The guidance for image generation, represents the
          influence of the prompt on the image generation.
        - `seed`: The seed for the image generation.
        - `strength`: strength for image generation, 0.0 - 1.0.
          Represents the percent of diffusion steps to run,
          setting the init_image as the noised latent at the
          given number of steps.
        - `init_image`: Base64 encoded image or path to image to use as the init image.

    Returns:
        StreamingResponse: The generated image as streaming jpeg bytes.
    """
    result = app.state.model.generate(**args.model_dump())
    return StreamingResponse(result, media_type="image/jpeg")


@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.

    Args:
        args (LoraArgs): Arguments for the LoRA action:

        - `scale`: The scaling factor for the LoRA weights.
        - `path`: The path to the LoRA checkpoint.
        - `name`: The name of the LoRA checkpoint.
        - `action`: The action to perform, either "load" or "unload".

    Returns:
        LoraLoadResponse: The status of the LoRA action.
    """
    try:
        if args.action == "load":
            app.state.model.load_lora(args.path, args.scale, args.name)
        elif args.action == "unload":
            app.state.model.unload_lora(args.name if args.name else args.path)
        else:
            return JSONResponse(
                content={
                    "status": "error",
                    "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
                },
                status_code=400,
            )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"status": "error", "message": str(e)}
        )
    return JSONResponse(status_code=200, content={"status": "success"})
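A minimal client sketch for the `/generate` endpoint above; the host and port are assumptions (whatever the ASGI server, e.g. uvicorn, is bound to), not something this file pins down:

import requests

# assumes the API above is being served on localhost:8000
payload = {"prompt": "a misty forest at dawn", "width": 1024, "height": 1024, "num_steps": 24}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
with open("out.jpg", "wb") as f:
    f.write(resp.content)  # the endpoint streams back raw JPEG bytes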
app.py
CHANGED
@@ -1,7 +1,170 @@
+import os
+import subprocess
+import spaces
+import torch
+from safetensors.torch import load_file
+from flux_pipeline import FluxPipeline
 import gradio as gr
+from PIL import Image
 
-def
-
+def download_models():
+    """
+    Download required models at application startup using wget.
+    """
+    model_urls = [
+        "https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/flux1-dev.safetensors",
+        "https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/ae.safetensors",
+    ]
+    for url in model_urls:
+        filename = url.split("/")[-1]
+        if not os.path.exists(filename):
+            print(f"Downloading {filename}...")
+            subprocess.run(["wget", "-O", filename, url], check=True)
+        else:
+            print(f"{filename} already exists, skipping download.")
 
-
-
+    print("All models are ready.")
+
+
+def load_sft(ckpt_path, device="cpu"):
+    """
+    Load a safetensors file.
+    Args:
+        ckpt_path (str): Local path to the safetensors file.
+        device (str): Device to load the file onto.
+    Returns:
+        Safetensors model state dictionary.
+    """
+    if os.path.exists(ckpt_path):
+        print(f"Loading local checkpoint: {ckpt_path}")
+        return load_file(ckpt_path, device=device)
+    else:
+        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
+
+
+def create_demo(config_path: str):
+    generator = FluxPipeline.load_pipeline_from_config_path(config_path)
+    load_sft("photo.safetensors", "cuda")
+    load_sft("dark.safetensors", "cuda")
+    load_sft("perfection.safetensors", "cuda")
+    @spaces.GPU
+    def generate_image(
+        prompt,
+        width,
+        height,
+        num_steps,
+        guidance,
+        seed,
+        init_image,
+        image2image_strength,
+        add_sampling_metadata,
+    ):
+
+        seed = int(seed)
+        if seed == -1:
+            seed = None
+        out = generator.generate(
+            prompt,
+            width,
+            height,
+            num_steps=num_steps,
+            guidance=guidance,
+            seed=seed,
+            init_image=init_image,
+            strength=image2image_strength,
+            silent=False,
+            num_images=1,
+            return_seed=True,
+        )
+        image_bytes = out[0]
+        return Image.open(image_bytes), str(out[1]), None
+
+    is_schnell = generator.config.version == "flux-schnell"
+
+    with gr.Blocks() as demo:
+        gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}")
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
+                )
+                do_img2img = gr.Checkbox(
+                    label="Image to Image", value=False, interactive=not is_schnell
+                )
+                init_image = gr.Image(label="Input Image", visible=False)
+                image2image_strength = gr.Slider(
+                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
+                )
+
+                with gr.Accordion("Advanced Options", open=False):
+                    width = gr.Slider(128, 8192, 1152, step=16, label="Width")
+                    height = gr.Slider(128, 8192, 640, step=16, label="Height")
+                    num_steps = gr.Slider(
+                        1, 50, 4 if is_schnell else 20, step=1, label="Number of steps"
+                    )
+                    guidance = gr.Slider(
+                        1.0,
+                        10.0,
+                        3.5,
+                        step=0.1,
+                        label="Guidance",
+                        interactive=not is_schnell,
+                    )
+                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
+                    add_sampling_metadata = gr.Checkbox(
+                        label="Add sampling parameters to metadata?", value=True
+                    )
+
+                generate_btn = gr.Button("Generate")
+
+            with gr.Column(min_width="960px"):
+                output_image = gr.Image(label="Generated Image")
+                seed_output = gr.Number(label="Used Seed")
+                warning_text = gr.Textbox(label="Warning", visible=False)
+
+        def update_img2img(do_img2img):
+            return {
+                init_image: gr.update(visible=do_img2img),
+                image2image_strength: gr.update(visible=do_img2img),
+            }
+
+        do_img2img.change(
+            update_img2img, do_img2img, [init_image, image2image_strength]
+        )
+
+        generate_btn.click(
+            fn=generate_image,
+            inputs=[
+                prompt,
+                width,
+                height,
+                num_steps,
+                guidance,
+                seed,
+                init_image,
+                image2image_strength,
+                add_sampling_metadata,
+            ],
+            outputs=[output_image, seed_output, warning_text],
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Flux")
+    parser.add_argument(
+        "--config", type=str, default="configs/config-dev-1-RTX6000ADA.json", help="Config file path"
+    )
+    parser.add_argument(
+        "--share", action="store_true", help="Create a public link to your demo"
+    )
+
+    args = parser.parse_args()
+
+    demo = create_demo(args.config)
+    demo.launch(share=args.share)
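The demo can also be driven programmatically rather than through the `__main__` block; a sketch, assuming the repo's dependencies (including `spaces`) are installed and the working directory is the repo root:

from app import create_demo

demo = create_demo("configs/config-dev-1-RTX6000ADA.json")
demo.launch(share=False)  # equivalent to `python app.py` without --share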
configs/config-dev-1-RTX6000ADA-Copy1.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": false,
  "offload_vae": false,
  "offload_flow": false
}
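The config variants below keep the same model hyperparameters and differ only in device placement (`*_device`), quantization (`*_quantization_dtype`), compilation (`compile_*`), and CPU offload (`offload_*`) for different GPU budgets. A minimal sketch of loading one directly, modeled on the `generator.generate(...)` call in app.py (the exact return shape is an assumption from that call site):

from PIL import Image
from flux_pipeline import FluxPipeline  # module added in this commit

pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev.json")
# with return_seed=True, app.py treats out[0] as JPEG bytes and out[1] as the seed
out = pipe.generate("a misty forest at dawn", 1152, 640,
                    num_steps=20, guidance=3.5, return_seed=True)
Image.open(out[0]).save("out.jpg")
print("seed:", out[1])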
configs/config-dev-1-RTX6000ADA.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": false,
  "offload_vae": false,
  "offload_flow": false
}
configs/config-dev-cuda0.json
ADDED
@@ -0,0 +1,56 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": false,
  "compile_blocks": false,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
configs/config-dev-eval.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:1",
  "ae_device": "cuda:1",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": false,
  "compile_blocks": false,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
configs/config-dev-gigaquant.json
ADDED
@@ -0,0 +1,58 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "num_to_quant": 220,
  "flow_quantization_dtype": "qint4",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qint4",
  "clip_quantization_dtype": "qint4",
  "compile_extras": false,
  "compile_blocks": false,
  "quantize_extras": true
}
configs/config-dev-offload-1-4080.json
ADDED
@@ -0,0 +1,58 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": true,
  "offload_vae": true,
  "offload_flow": true
}
configs/config-dev-offload-1-4090.json
ADDED
@@ -0,0 +1,58 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": true,
  "offload_vae": true,
  "offload_flow": false
}
configs/config-dev-offload.json
ADDED
@@ -0,0 +1,58 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": false,
  "compile_blocks": false,
  "offload_text_encoder": true,
  "offload_vae": true,
  "offload_flow": true
}
configs/config-dev-prequant.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/flux-fp16-acc/flux_fp8.safetensors",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:1",
  "ae_device": "cuda:1",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": false,
  "compile_blocks": false,
  "prequantized_flow": true,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
configs/config-dev.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:1",
  "ae_device": "cuda:1",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
configs/config-f8.json
ADDED
@@ -0,0 +1,48 @@
{
  "version": "flux-schnell",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": false
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "flux1-schnell.safetensors",
  "ae_path": "ae.safetensors",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 256,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": false,
  "offload_vae": false,
  "offload_flow": false
}
configs/config-schnell-cuda0.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-schnell",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": false
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-schnell",
  "repo_flow": "flux1-schnell.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 256,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": false,
  "compile_blocks": false,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
configs/config-schnell.json
ADDED
@@ -0,0 +1,57 @@
{
  "version": "flux-schnell",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": false
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-schnell",
  "repo_flow": "flux1-schnell.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 256,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:1",
  "ae_device": "cuda:1",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}
dark.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c633fc7d1af2452f0680abdc20baa285c43e107ae9a32fbf995c55c13bf0c4dd
size 39759552
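The three lines above are a Git LFS pointer, so the actual weights live in LFS storage. Once fetched, the file can be inspected with the same `safetensors` API that app.py uses:

from safetensors.torch import load_file

# requires `git lfs pull` first, so the pointer is replaced by real data
state_dict = load_file("dark.safetensors", device="cpu")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)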
f8.json
ADDED
@@ -0,0 +1,48 @@
{
  "version": "flux-schnell",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [16, 56, 56],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": false
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "flux1-schnell.safetensors",
  "ae_path": "ae.safetensors",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": false,
  "offload_vae": false,
  "offload_flow": false
}
float8_quantize.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from loguru import logger
import torch
import torch.nn as nn
from torch.nn import init
import math
from torch.compiler import is_compiling
from torch import __version__
from torch.version import cuda

from modules.flux_model import Modulation

# torch.__version__ supports tuple comparison; on torch 2.4.x torch._scaled_mm
# returns an (output, amax) tuple, while newer releases return the output alone
# (see forward() below).
IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
if LT_TORCH_2_4:
    if not hasattr(torch, "_scaled_mm"):
        raise RuntimeError(
            "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
        )
CUDA_VERSION = float(cuda) if cuda else 0
if CUDA_VERSION < 12.4:
    raise RuntimeError(
        "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later; "
        f"got torch version {__version__} and CUDA version {cuda}."
    )
try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = type(None)


class F8Linear(nn.Module):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=torch.float16,
        float8_dtype=torch.float8_e4m3fn,
        float_weight: torch.Tensor = None,
        float_bias: torch.Tensor = None,
        num_scale_trials: int = 12,
        input_float8_dtype=torch.float8_e5m2,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.float8_dtype = float8_dtype
        self.input_float8_dtype = input_float8_dtype
        self.input_scale_initialized = False
        self.weight_initialized = False
        self.max_value = torch.finfo(self.float8_dtype).max
        self.input_max_value = torch.finfo(self.input_float8_dtype).max
        factory_kwargs = {"dtype": dtype, "device": device}
        if float_weight is None:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), **factory_kwargs)
            )
        else:
            self.weight = nn.Parameter(
                float_weight, requires_grad=float_weight.requires_grad
            )
        if float_bias is None:
            if bias:
                self.bias = nn.Parameter(
                    torch.empty(out_features, **factory_kwargs),
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad)
        self.num_scale_trials = num_scale_trials
        self.input_amax_trials = torch.zeros(
            num_scale_trials, requires_grad=False, device=device, dtype=torch.float32
        )
        self.trial_index = 0
        # Buffers are filled in lazily: the weight scale when the weight is
        # quantized, the input scale after num_scale_trials calibration passes.
        self.register_buffer("scale", None)
        self.register_buffer("input_scale", None)
        self.register_buffer("float8_data", None)
        self.register_buffer("scale_reciprocal", None)
        self.register_buffer("input_scale_reciprocal", None)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
        if "weight" in sd:
            if (
                "float8_data" not in sd
                or sd["float8_data"] is None
                and sd["weight"].shape == (self.out_features, self.in_features)
            ):
                # Initialize as if it's an F8Linear that needs to be quantized
                self._parameters["weight"] = nn.Parameter(
                    sd["weight"], requires_grad=False
                )
                if "bias" in sd:
                    self._parameters["bias"] = nn.Parameter(
                        sd["bias"], requires_grad=False
                    )
                self.quantize_weight()
            elif sd["float8_data"].shape == (
                self.out_features,
                self.in_features,
            ) and torch.equal(sd["weight"], torch.zeros_like(sd["weight"])):
                w = sd["weight"]
                # Set the init values as if it's already quantized float8_data
                self._buffers["float8_data"] = sd["float8_data"]
                self._parameters["weight"] = nn.Parameter(
                    torch.zeros(
                        1,
                        dtype=w.dtype,
                        device=w.device,
                        requires_grad=False,
                    )
                )
                if "bias" in sd:
                    self._parameters["bias"] = nn.Parameter(
                        sd["bias"], requires_grad=False
                    )
                self.weight_initialized = True

                # Check if scales and reciprocals are initialized
                if all(
                    key in sd
                    for key in [
                        "scale",
                        "input_scale",
                        "scale_reciprocal",
                        "input_scale_reciprocal",
                    ]
                ):
                    self.scale = sd["scale"].float()
                    self.input_scale = sd["input_scale"].float()
                    self.scale_reciprocal = sd["scale_reciprocal"].float()
                    self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
                    self.input_scale_initialized = True
                    self.trial_index = self.num_scale_trials
                elif "scale" in sd and "scale_reciprocal" in sd:
                    self.scale = sd["scale"].float()
                    self.input_scale = (
                        sd["input_scale"].float() if "input_scale" in sd else None
                    )
                    self.scale_reciprocal = sd["scale_reciprocal"].float()
                    self.input_scale_reciprocal = (
                        sd["input_scale_reciprocal"].float()
                        if "input_scale_reciprocal" in sd
                        else None
                    )
                    self.input_scale_initialized = (
                        True if "input_scale" in sd else False
                    )
                    self.trial_index = (
                        self.num_scale_trials if "input_scale" in sd else 0
                    )
                    self.input_amax_trials = torch.zeros(
                        self.num_scale_trials,
                        requires_grad=False,
                        dtype=torch.float32,
                        device=self.weight.device,
                    )
                    self.input_scale_initialized = False
                    self.trial_index = 0
                else:
                    # If scales are not initialized, reset trials
                    self.input_scale_initialized = False
                    self.trial_index = 0
                    self.input_amax_trials = torch.zeros(
                        self.num_scale_trials, requires_grad=False, dtype=torch.float32
                    )
            else:
                raise RuntimeError(
                    f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}"
                )
        else:
            raise RuntimeError(
                "Weight tensor not found or has incorrect shape in state dict"
            )

    def quantize_weight(self):
        if self.weight_initialized:
            return
        amax = torch.max(torch.abs(self.weight.data)).float()
        self.scale = self.amax_to_scale(amax, self.max_value)
        self.float8_data = self.to_fp8_saturated(
            self.weight.data, self.scale, self.max_value
        ).to(self.float8_dtype)
        self.scale_reciprocal = self.scale.reciprocal()
        # Free the full-precision weight; only the float8 copy is kept.
        self.weight.data = torch.zeros(
            1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
        )
        self.weight_initialized = True

    def set_weight_tensor(self, tensor: torch.Tensor):
        self.weight.data = tensor
        self.weight_initialized = False
        self.quantize_weight()

    def amax_to_scale(self, amax, max_val):
        # Scale so the largest observed magnitude maps to the float8 max.
        return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)

    def to_fp8_saturated(self, x, scale, max_val):
        return (x * scale).clamp(-max_val, max_val)

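    # Worked example (illustrative comment, not in the original file): with
    # amax = 2.0 and torch.finfo(torch.float8_e4m3fn).max == 448.0, the weight
    # scale is 448 / 2 = 224; weights are multiplied by 224 before the fp8 cast,
    # and _scaled_mm rescales the output using the stored reciprocals.
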
    def quantize_input(self, x: torch.Tensor):
        if self.input_scale_initialized:
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        elif self.trial_index < self.num_scale_trials:
            # Calibration phase: track the running max over the first
            # num_scale_trials inputs and derive the input scale from it.
            amax = torch.max(torch.abs(x)).float()

            self.input_amax_trials[self.trial_index] = amax
            self.trial_index += 1
            self.input_scale = self.amax_to_scale(
                self.input_amax_trials[: self.trial_index].max(), self.input_max_value
            )
            self.input_scale_reciprocal = self.input_scale.reciprocal()
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        else:
            self.input_scale = self.amax_to_scale(
                self.input_amax_trials.max(), self.input_max_value
            )
            self.input_scale_reciprocal = self.input_scale.reciprocal()
            self.input_scale_initialized = True
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )

    def reset_parameters(self) -> None:
        if self.weight_initialized:
            self.weight = nn.Parameter(
                torch.empty(
                    (self.out_features, self.in_features),
                    dtype=self.weight.dtype,
                    device=self.weight.device,
                )
            )
            self.weight_initialized = False
            self.input_scale_initialized = False
            self.trial_index = 0
            self.input_amax_trials.zero_()
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)
        self.quantize_weight()
        self.max_value = torch.finfo(self.float8_dtype).max
        self.input_max_value = torch.finfo(self.input_float8_dtype).max

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.input_scale_initialized or is_compiling():
            x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        else:
            x = self.quantize_input(x)

        prev_dims = x.shape[:-1]
        x = x.view(-1, self.in_features)

        # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
        out = torch._scaled_mm(
            x,
            self.float8_data.T,
            scale_a=self.input_scale_reciprocal,
            scale_b=self.scale_reciprocal,
            bias=self.bias,
            out_dtype=self.weight.dtype,
            use_fast_accum=True,
        )
        if IS_TORCH_2_4:
            # torch 2.4's _scaled_mm returns an (output, amax) tuple
            out = out[0]
        out = out.view(*prev_dims, self.out_features)
        return out

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        float8_dtype=torch.float8_e4m3fn,
        input_float8_dtype=torch.float8_e5m2,
    ) -> "F8Linear":
        f8_lin = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
            float8_dtype=float8_dtype,
            float_weight=linear.weight.data,
            float_bias=(linear.bias.data if linear.bias is not None else None),
            input_float8_dtype=input_float8_dtype,
        )
        f8_lin.quantize_weight()
        return f8_lin


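# Illustrative usage (comment only, not part of the original module): converting
# a plain linear layer and running a forward pass on an ADA-class GPU.
#
#   lin = nn.Linear(4096, 4096, dtype=torch.float16, device="cuda")
#   f8 = F8Linear.from_linear(lin)  # weights quantized to float8_e4m3fn
#   y = f8(torch.randn(1, 77, 4096, dtype=torch.float16, device="cuda"))

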
@torch.inference_mode()
def recursive_swap_linears(
    model: nn.Module,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    quantize_modulation: bool = True,
    ignore_keys: list[str] = [],
) -> None:
    """
    Recursively swaps all nn.Linear modules in the given model with F8Linear modules.

    This function traverses the model's structure and replaces each nn.Linear
    instance with an F8Linear instance, which uses 8-bit floating point
    quantization for weights. The original linear layer's weights are deleted
    after conversion to save memory.

    Args:
        model (nn.Module): The PyTorch model to modify.

    Note:
        This function modifies the model in-place. After calling this function,
        all linear layers in the model will be using 8-bit quantization.
    """
    for name, child in model.named_children():
        if name in ignore_keys:
            continue
        if isinstance(child, Modulation) and not quantize_modulation:
            continue
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            setattr(
                model,
                name,
                F8Linear.from_linear(
                    child,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                ),
            )
            del child
        else:
            recursive_swap_linears(
                child,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
                ignore_keys=ignore_keys,
            )


@torch.inference_mode()
def swap_to_cublaslinear(model: nn.Module):
    if CublasLinear == type(None):
        return
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            cublas_lin = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            cublas_lin.weight.data = child.weight.clone().detach()
            # guard against bias-free linears before copying the bias
            if child.bias is not None:
                cublas_lin.bias.data = child.bias.clone().detach()
            setattr(model, name, cublas_lin)
            del child
        else:
            swap_to_cublaslinear(child)


@torch.inference_mode()
def quantize_flow_transformer_and_dispatch_float8(
    flow_model: nn.Module,
    device=torch.device("cuda"),
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    offload_flow=False,
    swap_linears_with_cublaslinear=True,
    flow_dtype=torch.float16,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = True,
) -> nn.Module:
    """
    Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.

    Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes.

    Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.

    After dispatching, if offload_flow is True, offloads the model to cpu.

    If swap_linears_with_cublaslinear is True and flow_dtype == torch.float16, swaps all linears
    with cublas linears for a 2x performance boost on consumer GPUs; otherwise the cublas linear
    swap is skipped.

    For extra precision, you can set quantize_flow_embedder_layers to False;
    this helps maintain the output quality of the flow transformer more than fully quantizing,
    at the expense of ~512MB more VRAM usage.

    For extra precision, you can set quantize_modulation to False;
    this helps maintain the output quality of the flow transformer more than fully quantizing,
    at the expense of ~2GB more VRAM usage, and has a much higher impact on image quality than
    the embedder layers.
    """
    for module in flow_model.double_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    for module in flow_model.single_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]
    for module in to_gpu_extras:
        m_extra = getattr(flow_model, module)
        if m_extra is None:
            continue
        m_extra.to(device)
        m_extra.eval()
        if isinstance(m_extra, nn.Linear) and not isinstance(
            m_extra, (F8Linear, CublasLinear)
        ):
            if quantize_flow_embedder_layers:
                setattr(
                    flow_model,
                    module,
                    F8Linear.from_linear(
                        m_extra,
                        float8_dtype=float8_dtype,
                        input_float8_dtype=input_float8_dtype,
                    ),
                )
                del m_extra
        elif module != "final_layer":
            if quantize_flow_embedder_layers:
                recursive_swap_linears(
                    m_extra,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                    quantize_modulation=quantize_modulation,
                )
        torch.cuda.empty_cache()
    if (
        swap_linears_with_cublaslinear
        and flow_dtype == torch.float16
        and CublasLinear != type(None)
    ):
        swap_to_cublaslinear(flow_model)
    elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
        logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
    if offload_flow:
        flow_model.to("cpu")
        torch.cuda.empty_cache()
    return flow_model
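
float8_quantize.py is the quantization entry point the pipeline uses when a config requests qfloat8 weights. A minimal sketch of calling it directly, assuming `flow_model` is a Flux transformer already constructed elsewhere in the repo (the argument values are illustrative, not prescribed defaults):

import torch
from float8_quantize import quantize_flow_transformer_and_dispatch_float8

# Sketch only: `flow_model` is assumed to be an already-built
# modules.flux_model.Flux instance loaded in float16.
flow_model = quantize_flow_transformer_and_dispatch_float8(
    flow_model,
    device=torch.device("cuda:0"),
    offload_flow=False,
    quantize_modulation=True,
    quantize_flow_embedder_layers=False,  # trades ~512MB VRAM for extra precision
)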
flux_emphasis.py
ADDED
@@ -0,0 +1,447 @@
from typing import TYPE_CHECKING, Optional
from pydash import flatten

import torch
from transformers.models.clip.tokenization_clip import CLIPTokenizer
from einops import repeat

if TYPE_CHECKING:
    from flux_pipeline import FluxPipeline


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
    (abc) - increases attention to abc by a multiplier of 1.1
    (abc:3.12) - increases attention to abc by a multiplier of 3.12
    [abc] - decreases attention to abc by a multiplier of 1.1
    \\( - literal character '('
    \\[ - literal character '['
    \\) - literal character ')'
    \\] - literal character ']'
    \\ - literal character '\'
    anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    import re

    re_attention = re.compile(
        r"""
        \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
        \)|]|[^\\()\[\]:]+|:
        """,
        re.X,
    )

    re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_tokens_with_weights(
    clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
):
    """
    Get prompt token ids and weights. This function works for both prompt and negative prompt.

    Args:
        clip_tokenizer (CLIPTokenizer)
            A CLIPTokenizer
        prompt (str)
            A prompt string with weights

    Returns:
        text_tokens (list)
            A list containing token ids
        text_weight (list)
            A list containing the corresponding weight of each token id

    Example:
        import torch
        from transformers import CLIPTokenizer

        clip_tokenizer = CLIPTokenizer.from_pretrained(
            "stablediffusionapi/deliberate-v2"
            , subfolder = "tokenizer"
            , dtype = torch.float16
        )

        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer = clip_tokenizer
            ,prompt = "a (red:1.5) cat"*70
        )
    """
    texts_and_weights = parse_prompt_attention(prompt)
    text_tokens, text_weights = [], []
    maxlen = clip_tokenizer.model_max_length
    for word, weight in texts_and_weights:
        # tokenize without special tokens so prompts of any length can be handled;
        # the starting and ending tokens are added later, per chunk
        token = clip_tokenizer(
            word, truncation=False, padding=False, add_special_tokens=False
        ).input_ids
        # the returned token is a 1d list: [320, 1125, 539, 320]
        if debug:
            print(
                token,
                "|FOR MODEL LEN{}|".format(maxlen),
                clip_tokenizer.decode(
                    token, skip_special_tokens=True, clean_up_tokenization_spaces=True
                ),
            )
        # merge the new tokens into the running token list: text_tokens
        text_tokens = [*text_tokens, *token]

        # each token chunk comes with one weight, like ['red cat', 2.0];
        # expand the weight so each token gets its own copy.
        chunk_weights = [weight] * len(token)

        # append the weights to the running weight list: text_weights
        text_weights = [*text_weights, *chunk_weights]
    return text_tokens, text_weights


def group_tokens_and_weights(
    token_ids: list,
    weights: list,
    pad_last_block=False,
    bos=49406,
    eos=49407,
    max_length=77,
    pad_tokens=True,
):
    """
    Produce tokens and weights in groups and pad the missing tokens.

    Args:
        token_ids (list)
            The token ids from the tokenizer
        weights (list)
            The weights list from the function get_prompts_tokens_with_weights
        pad_last_block (bool)
            Controls whether to fill the last token group to 75 tokens with eos
    Returns:
        new_token_ids (2d list)
        new_weights (2d list)

    Example:
        token_groups, weight_groups = group_tokens_and_weights(
            token_ids = token_id_list
            , weights = token_weight_list
        )
    """
    # TODO: Possibly need to fix this, since this doesn't seem correct.
    # Ignoring for now since I don't know what the consequences might be
    # if changed to <= instead of <.
    max_len = max_length - 2 if max_length < 77 else max_length
    # this will be a 2d list
    new_token_ids = []
    new_weights = []
    while len(token_ids) >= max_len:
        # pull the next chunk of max_len tokens off the front
        temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)]
        temp_77_weights = [weights.pop(0) for _ in range(max_len)]

        if pad_tokens:
            if bos is not None:
                temp_77_token_ids = [bos] + temp_77_token_ids + [eos]
                temp_77_weights = [1.0] + temp_77_weights + [1.0]
            else:
                temp_77_token_ids = temp_77_token_ids + [eos]
                temp_77_weights = temp_77_weights + [1.0]

        # add the token and weight chunk to the holder lists
        new_token_ids.append(temp_77_token_ids)
        new_weights.append(temp_77_weights)

    # pad whatever tokens are left over
    if len(token_ids) > 0:
        if pad_tokens:
            padding_len = max_len - len(token_ids) if pad_last_block else 0

            temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
            new_token_ids.append(temp_77_token_ids)

            temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
            new_weights.append(temp_77_weights)
        else:
            new_token_ids.append(token_ids)
            new_weights.append(weights)
    return new_token_ids, new_weights


def standardize_tensor(
    input_tensor: torch.Tensor, target_mean: float, target_std: float
) -> torch.Tensor:
    """
    This function standardizes an input tensor so that it has a specific mean and standard deviation.

    Parameters:
        input_tensor (torch.Tensor): The tensor to standardize.
        target_mean (float): The target mean for the tensor.
        target_std (float): The target standard deviation for the tensor.

    Returns:
        torch.Tensor: The standardized tensor.
    """

    # First, compute the mean and std of the input tensor
    mean = input_tensor.mean()
    std = input_tensor.std()

    # Then, standardize the tensor to have a mean of 0 and std of 1
    standardized_tensor = (input_tensor - mean) / std

    # Finally, scale the tensor to the target mean and std
    output_tensor = standardized_tensor * target_std + target_mean

    return output_tensor


def apply_weights(
    prompt_tokens: torch.Tensor,
    weight_tensor: torch.Tensor,
    token_embedding: torch.Tensor,
    eos_token_id: int,
    pad_last_block: bool = True,
) -> torch.FloatTensor:
    mean = token_embedding.mean()
    std = token_embedding.std()
    if pad_last_block:
        # pool at the position of the first eos token in each sequence
        pooled_tensor = token_embedding[
            torch.arange(token_embedding.shape[0], device=token_embedding.device),
            (
                prompt_tokens.to(dtype=torch.int, device=token_embedding.device)
                == eos_token_id
            )
            .int()
            .argmax(dim=-1),
        ]
    else:
        pooled_tensor = token_embedding[:, -1]

    # Each weighted token embedding is moved away from (w > 1) or toward (w < 1)
    # the pooled embedding, then the whole tensor is re-standardized to the
    # original mean/std so overall magnitudes stay stable.
    for j in range(len(weight_tensor)):
        if weight_tensor[j] != 1.0:
            token_embedding[:, j] = (
                pooled_tensor
                + (token_embedding[:, j] - pooled_tensor) * weight_tensor[j]
            )
    return standardize_tensor(token_embedding, mean, std)


@torch.inference_mode()
def get_weighted_text_embeddings_flux(
    pipe: "FluxPipeline",
    prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    target_device: Optional[torch.device] = torch.device("cuda:0"),
    target_dtype: Optional[torch.dtype] = torch.bfloat16,
    debug: bool = False,
):
    """
    This function processes long, weighted prompts for Flux with no length limitation.

    Args:
        pipe (FluxPipeline)
        prompt (str)
        num_images_per_prompt (int)
        device (torch.device)
    Returns:
        clip_embeds (torch.Tensor)
        t5_embeds (torch.Tensor)
        txt_ids (torch.Tensor)
    """
    device = device or pipe._execution_device

    eos = pipe.clip.tokenizer.eos_token_id
    eos_2 = pipe.t5.tokenizer.eos_token_id
    bos = pipe.clip.tokenizer.bos_token_id
    bos_2 = pipe.t5.tokenizer.bos_token_id

    clip = pipe.clip.hf_module
    t5 = pipe.t5.hf_module

    tokenizer_clip = pipe.clip.tokenizer
    tokenizer_t5 = pipe.t5.tokenizer

    t5_length = 512 if pipe.name == "flux-dev" else 256
    clip_length = 77

    # tokenizer 1
    prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
        tokenizer_clip, prompt, debug=debug
    )

    # tokenizer 2
    prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
        tokenizer_t5, prompt, debug=debug
    )

    prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
        prompt_tokens_clip,
        prompt_weights_clip,
        pad_last_block=True,
        bos=bos,
        eos=eos,
        max_length=clip_length,
    )
    prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights(
        prompt_tokens_t5,
        prompt_weights_t5,
        pad_last_block=True,
        bos=bos_2,
        eos=eos_2,
        max_length=t5_length,
        pad_tokens=False,
    )
    prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped)
    prompt_weights_t5 = flatten(prompt_weights_t5_grouped)
    prompt_tokens_clip = flatten(prompt_tokens_clip_grouped)
    prompt_weights_clip = flatten(prompt_weights_clip_grouped)

    prompt_tokens_clip = tokenizer_clip.decode(
        prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_clip = tokenizer_clip(
        prompt_tokens_clip,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=clip_length,
        return_tensors="pt",
    ).input_ids.to(device)
    prompt_tokens_t5 = tokenizer_t5.decode(
        prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_t5 = tokenizer_t5(
        prompt_tokens_t5,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=t5_length,
        return_tensors="pt",
    ).input_ids.to(device)

    prompt_weights_t5 = torch.cat(
        [
            torch.tensor(prompt_weights_t5, dtype=torch.float32),
            torch.full(
                (t5_length - torch.tensor(prompt_weights_t5).numel(),),
                1.0,
                dtype=torch.float32,
            ),
        ],
        dim=0,
    ).to(device)

    clip_embeds = clip(
        prompt_tokens_clip, output_hidden_states=True, attention_mask=None
    )["pooler_output"]
    if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)

    weight_tensor_t5 = torch.tensor(
        flatten(prompt_weights_t5), dtype=torch.float32, device=device
    )
    t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[
        "last_hidden_state"
    ]
    t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
    if debug:
        print(t5_embeds.shape)
    if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
    txt_ids = torch.zeros(
        num_images_per_prompt,
        t5_embeds.shape[1],
        3,
        device=target_device,
        dtype=target_dtype,
    )
    t5_embeds = t5_embeds.to(target_device, dtype=target_dtype)
    clip_embeds = clip_embeds.to(target_device, dtype=target_dtype)

    return (
        clip_embeds,
        t5_embeds,
        txt_ids,
    )
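
flux_emphasis.py ends here; the key idea is that token weights scale each embedding's offset from the pooled (eos) embedding before re-standardization. A self-contained toy sketch of that core step, with made-up shapes (illustrative only; real inputs come from the T5 encoder above):

import torch

emb = torch.randn(1, 4, 8)      # stand-in for (batch, tokens, dim) hidden states
pooled = emb[:, -1]             # stand-in for the pooled vector
weights = [1.0, 1.5, 0.5, 1.0]  # e.g. from parse_prompt_attention
for j, w in enumerate(weights):
    if w != 1.0:
        # push token j away from (w > 1) or toward (w < 1) the pooled vector
        emb[:, j] = pooled + (emb[:, j] - pooled) * w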
flux_pipeline.py
ADDED
@@ -0,0 +1,729 @@
import io
import math
import random
import warnings
from typing import TYPE_CHECKING, Callable, List, Optional, OrderedDict, Union

import numpy as np
from PIL import Image

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
from einops import rearrange

from flux_emphasis import get_weighted_text_embeddings_flux

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
from pybase64 import standard_b64decode
from torch._dynamo import config
from torch._inductor import config as ind_config

config.cache_size_limit = 10000000000
ind_config.shape_padding = True
import platform

from loguru import logger
from torchvision.transforms import functional as TF
from tqdm import tqdm

import lora_loading
from image_encoder import ImageEncoder
from util import (
    ModelSpec,
    ModelVersion,
    into_device,
    into_dtype,
    load_config_from_path,
    load_models_from_config,
)

if platform.system() == "Windows":
    MAX_RAND = 2**16 - 1
else:
    MAX_RAND = 2**32 - 1


if TYPE_CHECKING:
    from modules.autoencoder import AutoEncoder
    from modules.conditioner import HFEmbedder
    from modules.flux_model import Flux


class FluxPipeline:
    """
    FluxPipeline is a class that provides a pipeline for generating images using the Flux model.
    It handles input preparation, timestep generation, noise generation, device management
    and model compilation.
    """

    def __init__(
        self,
        name: str,
        offload: bool = False,
        clip: "HFEmbedder" = None,
        t5: "HFEmbedder" = None,
        model: "Flux" = None,
        ae: "AutoEncoder" = None,
        dtype: torch.dtype = torch.float16,
        verbose: bool = False,
        flux_device: torch.device | str = "cuda:0",
        ae_device: torch.device | str = "cuda:1",
        clip_device: torch.device | str = "cuda:1",
        t5_device: torch.device | str = "cuda:1",
        config: ModelSpec = None,
        debug: bool = False,
    ):
        """
        Initialize the FluxPipeline class.

        This class is responsible for preparing input tensors for the Flux model, generating
        timesteps and noise, and handling device management for model offloading.
        """

        if config is None:
            raise ValueError("ModelSpec config is required!")

        self.debug = debug
        self.name = name
        self.device_flux = into_device(flux_device)
        self.device_ae = into_device(ae_device)
        self.device_clip = into_device(clip_device)
        self.device_t5 = into_device(t5_device)
        self.dtype = into_dtype(dtype)
        self.offload = offload
        self.clip: "HFEmbedder" = clip
        self.t5: "HFEmbedder" = t5
        self.model: "Flux" = model
        self.ae: "AutoEncoder" = ae
        self.rng = torch.Generator(device="cpu")
        self.img_encoder = ImageEncoder()
        self.verbose = verbose
        self.ae_dtype = torch.bfloat16
        self.config = config
        self.offload_text_encoder = config.offload_text_encoder
        self.offload_vae = config.offload_vae
        self.offload_flow = config.offload_flow

        # If models are not offloaded, move them to the appropriate devices
        if not self.offload_flow:
            self.model.to(self.device_flux)
        if not self.offload_vae:
            self.ae.to(self.device_ae)
        if not self.offload_text_encoder:
            self.clip.to(self.device_clip)
            self.t5.to(self.device_t5)

        # compile the model if needed
        if config.compile_blocks or config.compile_extras:
            self.compile()

    def set_seed(
        self, seed: int | None = None, seed_globally: bool = False
    ) -> tuple[torch.Generator, int]:
        if isinstance(seed, (int, float)):
            seed = int(abs(seed)) % MAX_RAND
            cuda_generator = torch.Generator("cuda").manual_seed(seed)
        elif isinstance(seed, str):
            try:
                seed = abs(int(seed)) % MAX_RAND
            except Exception:
                logger.warning(
                    f"Received string representation of seed, but was not able to convert to int: {seed}, using random seed"
                )
                seed = abs(self.rng.seed()) % MAX_RAND
            cuda_generator = torch.Generator("cuda").manual_seed(seed)
        else:
            seed = abs(self.rng.seed()) % MAX_RAND
            cuda_generator = torch.Generator("cuda").manual_seed(seed)

        if seed_globally:
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)
            random.seed(seed)
        return cuda_generator, seed

    def load_lora(
        self,
        lora_path: Union[str, OrderedDict[str, torch.Tensor]],
        scale: float,
        name: Optional[str] = None,
    ):
        """
        Loads a LoRA checkpoint into the Flux flow transformer.

        Currently supports LoRA checkpoints from diffusers, whose keys usually start with transformer.[...],
        and LoRAs whose keys start with lora_unet_[...].

        Args:
            lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights.
            scale (float): Scaling factor for the LoRA weights.
            name (str): Name of the LoRA checkpoint; can optionally be left as None, since it only acts as an identifier.
        """
        self.model.load_lora(path=lora_path, scale=scale, name=name)

    def unload_lora(self, path_or_identifier: str):
        """
        Unloads the LoRA checkpoint from the Flux flow transformer.

        Args:
            path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
        """
        self.model.unload_lora(path_or_identifier=path_or_identifier)

    @torch.inference_mode()
    def compile(self):
        """
        Compiles the model and extras.

        Handles two kinds of checkpoint:

        - A) A checkpoint which already has float8 quantized weights and tuned input scales,
            in which case no warmups are run, since the input scales are assumed to be tuned already.

        - B) A checkpoint which has not been quantized, in which case it is quantized
            and the input scales are tuned via a warmup loop:
            - If the model is flux-schnell, it runs 3 warmup loops, since each loop is 4 steps.
            - If the model is flux-dev, it runs 1 warmup loop of 12 steps.
        """

        # Run warmups if the checkpoint is not prequantized
        if not self.config.prequantized_flow:
            logger.info("Running warmups for compile...")
            warmup_dict = dict(
                prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
                height=768,
                width=768,
                num_steps=12,
                guidance=3.5,
                seed=10,
            )
            if self.config.version == ModelVersion.flux_schnell:
                warmup_dict["num_steps"] = 4
                for _ in range(3):
                    self.generate(**warmup_dict)
            else:
                self.generate(**warmup_dict)

        # Compile the model and extras
        to_gpu_extras = [
            "vector_in",
            "img_in",
            "txt_in",
            "time_in",
            "guidance_in",
            "final_layer",
            "pe_embedder",
        ]
        if self.config.compile_blocks:
            for block in self.model.double_blocks:
                block.compile()
            for block in self.model.single_blocks:
                block.compile()
        if self.config.compile_extras:
            for extra in to_gpu_extras:
                getattr(self.model, extra).compile()

    @torch.inference_mode()
    def prepare(
        self,
        img: torch.Tensor,
        prompt: str | list[str],
        target_device: torch.device = torch.device("cuda:0"),
        target_dtype: torch.dtype = torch.float16,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare input tensors for the Flux model.

        This function processes the input image and text prompt, converting them into
        the appropriate format and embedding representations required by the model.

        Args:
            img (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
            prompt (str | list[str]): Text prompt or list of prompts guiding the image generation.
            target_device (torch.device, optional): The target device for the output tensors.
                Defaults to torch.device("cuda:0").
            target_dtype (torch.dtype, optional): The target data type for the output tensors.
                Defaults to torch.float16.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - img: Processed image tensor.
                - img_ids: Image position IDs.
                - vec: Clip text embedding vector.
                - txt: T5 text embedding hidden states.
                - txt_ids: Text position IDs.

        Note:
            This function handles the necessary device management for text encoder offloading
            if enabled in the configuration.
        """
        bs, c, h, w = img.shape
        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
        # pack 2x2 latent patches into the sequence dimension
        img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5)
        img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5])
        assert img.shape == (
            bs,
            (h // 2) * (w // 2),
            c * 2 * 2,
        ), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}"
        if img.shape[0] == 1 and bs > 1:
            img = img.repeat_interleave(bs, dim=0)

        img_ids = torch.zeros(
            h // 2, w // 2, 3, device=target_device, dtype=target_dtype
        )
        img_ids[..., 1] = (
            img_ids[..., 1]
            + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
        )
        img_ids[..., 2] = (
            img_ids[..., 2]
            + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
        )

        img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
        if self.offload_text_encoder:
            self.clip.to(device=self.device_clip)
            self.t5.to(device=self.device_t5)

        # get the text embeddings
        vec, txt, txt_ids = get_weighted_text_embeddings_flux(
            self,
            prompt,
            num_images_per_prompt=bs,
            device=self.device_clip,
            target_device=target_device,
            target_dtype=target_dtype,
            debug=self.debug,
        )
        # offload text encoder to cpu if needed
        if self.offload_text_encoder:
            self.clip.to("cpu")
            self.t5.to("cpu")
            torch.cuda.empty_cache()
        return img, img_ids, vec, txt, txt_ids

    @torch.inference_mode()
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def get_lin_function(
        self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
    ) -> Callable[[float], float]:
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    @torch.inference_mode()
    def get_schedule(
        self,
        num_steps: int,
        image_seq_len: int,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        shift: bool = True,
    ) -> list[float]:
        """Generates a schedule of timesteps for the given number of steps and image sequence length."""
        # extra step for zero
        timesteps = torch.linspace(1, 0, num_steps + 1)

        # shifting the schedule to favor high timesteps for higher signal images
        if shift:
            # estimate mu based on linear interpolation between two points
            mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
            timesteps = self.time_shift(mu, 1.0, timesteps)

        return timesteps.tolist()

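    # Worked example (illustrative comment, not in the original file): a
    # 1024x1024 image packs to an image_seq_len of (1024/16)**2 = 4096, so with
    # the defaults mu = 1.15, and a midpoint t = 0.5 shifts to
    # e^1.15 / (e^1.15 + 1) ≈ 0.76, keeping more steps at high noise levels.
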
    @torch.inference_mode()
    def get_noise(
        self,
        num_samples: int,
        height: int,
        width: int,
        generator: torch.Generator,
        dtype=None,
        device=None,
    ) -> torch.Tensor:
        """Generates a latent noise tensor of the given shape and dtype on the given device."""
        if device is None:
            device = self.device_flux
        if dtype is None:
            dtype = self.dtype
        return torch.randn(
            num_samples,
            16,
            # allow for packing
            2 * math.ceil(height / 16),
            2 * math.ceil(width / 16),
            device=device,
            dtype=dtype,
            generator=generator,
            requires_grad=False,
        )

    @torch.inference_mode()
    def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO:
        """Converts the image tensor to bytes."""
        # bring into PIL format and save
        num_images = x.shape[0]
        images: List[torch.Tensor] = []
        for i in range(num_images):
            # use a separate name so the batch tensor x is not clobbered mid-loop
            img = (
                x[i]
                .clamp(-1, 1)
                .add(1.0)
                .mul(127.5)
                .clamp(0, 255)
                .contiguous()
                .type(torch.uint8)
            )
            images.append(img)
        if len(images) == 1:
            im = images[0]
        else:
            im = torch.vstack(images)

        im = self.img_encoder.encode_torch(im, quality=jpeg_quality)
        images.clear()
        return im

    @torch.inference_mode()
    def load_init_image_if_needed(
        self, init_image: torch.Tensor | str | Image.Image | np.ndarray
    ) -> torch.Tensor:
        """
        Loads the initial image if it is a string, numpy array, or PIL.Image;
        if it is a torch.Tensor, it is expected to be in the correct format and is returned as is.
        """
        if isinstance(init_image, str):
            try:
                init_image = Image.open(init_image)
            except Exception:
                # not a file path; treat the string as a base64-encoded image
                init_image = Image.open(
                    io.BytesIO(standard_b64decode(init_image.split(",")[-1]))
                )
            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
        elif isinstance(init_image, np.ndarray):
            init_image = torch.from_numpy(init_image).type(torch.uint8)
        elif isinstance(init_image, Image.Image):
            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)

        return init_image

    @torch.inference_mode()
    def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Decodes the latent tensor to the pixel space."""
        if self.offload_vae:
            self.ae.to(self.device_ae)
        x = x.to(self.device_ae)
        x = self.unpack(x.float(), height, width)
        with torch.autocast(
            device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
        ):
            x = self.ae.decode(x)
        if self.offload_vae:
            self.ae.to("cpu")
            torch.cuda.empty_cache()
        return x

    def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 16),
            w=math.ceil(width / 16),
            ph=2,
            pw=2,
        )

    @torch.inference_mode()
    def resize_center_crop(
        self, img: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """Resizes and crops the image to the given height and width."""
        img = TF.resize(img, min(width, height))
        img = TF.center_crop(img, (height, width))
        return img

459 |
+
@torch.inference_mode()
|
460 |
+
def preprocess_latent(
|
461 |
+
self,
|
462 |
+
init_image: torch.Tensor | np.ndarray = None,
|
463 |
+
height: int = 720,
|
464 |
+
width: int = 1024,
|
465 |
+
num_steps: int = 20,
|
466 |
+
strength: float = 1.0,
|
467 |
+
generator: torch.Generator = None,
|
468 |
+
num_images: int = 1,
|
469 |
+
) -> tuple[torch.Tensor, List[float]]:
|
470 |
+
"""
|
471 |
+
Preprocesses the latent tensor for the given number of steps and image sequence length.
|
472 |
+
Also, if an initial image is provided, it is vae encoded and injected with the appropriate noise
|
473 |
+
given the strength and number of steps replacing the latent tensor.
|
474 |
+
"""
|
475 |
+
# prepare input
|
476 |
+
|
477 |
+
if init_image is not None:
|
478 |
+
if isinstance(init_image, np.ndarray):
|
479 |
+
init_image = torch.from_numpy(init_image)
|
480 |
+
|
481 |
+
init_image = (
|
482 |
+
init_image.permute(2, 0, 1)
|
483 |
+
.contiguous()
|
484 |
+
.to(self.device_ae, dtype=self.ae_dtype)
|
485 |
+
.div(127.5)
|
486 |
+
.sub(1)[None, ...]
|
487 |
+
)
|
488 |
+
init_image = self.resize_center_crop(init_image, height, width)
|
489 |
+
with torch.autocast(
|
490 |
+
device_type=self.device_ae.type,
|
491 |
+
dtype=torch.bfloat16,
|
492 |
+
cache_enabled=False,
|
493 |
+
):
|
494 |
+
if self.offload_vae:
|
495 |
+
self.ae.to(self.device_ae)
|
496 |
+
init_image = (
|
497 |
+
self.ae.encode(init_image)
|
498 |
+
.to(dtype=self.dtype, device=self.device_flux)
|
499 |
+
.repeat(num_images, 1, 1, 1)
|
500 |
+
)
|
501 |
+
if self.offload_vae:
|
502 |
+
self.ae.to("cpu")
|
503 |
+
torch.cuda.empty_cache()
|
504 |
+
|
505 |
+
x = self.get_noise(
|
506 |
+
num_images,
|
507 |
+
height,
|
508 |
+
width,
|
509 |
+
device=self.device_flux,
|
510 |
+
dtype=self.dtype,
|
511 |
+
generator=generator,
|
512 |
+
)
|
513 |
+
timesteps = self.get_schedule(
|
514 |
+
num_steps=num_steps,
|
515 |
+
image_seq_len=x.shape[-1] * x.shape[-2] // 4,
|
516 |
+
shift=(self.name != "flux-schnell"),
|
517 |
+
)
|
518 |
+
if init_image is not None:
|
519 |
+
t_idx = int((1 - strength) * num_steps)
|
520 |
+
t = timesteps[t_idx]
|
521 |
+
timesteps = timesteps[t_idx:]
|
522 |
+
x = t * x + (1.0 - t) * init_image
|
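            # This blend is the rectified-flow img2img starting point: pure noise
            # and the VAE-encoded init image are mixed at timestep t, and denoising
            # then runs only over the truncated schedule timesteps[t_idx:].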
523 |
+
return x, timesteps
|
524 |
+
|
525 |
+
@torch.inference_mode()
|
526 |
+
def generate(
|
527 |
+
self,
|
528 |
+
prompt: str,
|
529 |
+
width: int = 720,
|
530 |
+
height: int = 1024,
|
531 |
+
num_steps: int = 24,
|
532 |
+
guidance: float = 3.5,
|
533 |
+
seed: int | None = None,
|
534 |
+
init_image: torch.Tensor | str | Image.Image | np.ndarray | None = None,
|
535 |
+
strength: float = 1.0,
|
536 |
+
silent: bool = False,
|
537 |
+
num_images: int = 1,
|
538 |
+
return_seed: bool = False,
|
539 |
+
jpeg_quality: int = 99,
|
540 |
+
) -> io.BytesIO:
|
541 |
+
"""
|
542 |
+
Generate images based on the given prompt and parameters.
|
543 |
+
|
544 |
+
Args:
|
545 |
+
prompt `(str)`: The text prompt to guide the image generation.
|
546 |
+
|
547 |
+
width `(int, optional)`: Width of the generated image. Defaults to 720.
|
548 |
+
|
549 |
+
height `(int, optional)`: Height of the generated image. Defaults to 1024.
|
550 |
+
|
551 |
+
num_steps `(int, optional)`: Number of denoising steps. Defaults to 24.
|
552 |
+
|
553 |
+
guidance `(float, optional)`: Guidance scale for text-to-image generation. Defaults to 3.5.
|
554 |
+
|
555 |
+
seed `(int | None, optional)`: Random seed for reproducibility. If None, a random seed is used. Defaults to None.
|
556 |
+
|
557 |
+
init_image `(torch.Tensor | str | Image.Image | np.ndarray | None, optional)`: Initial image for image-to-image generation. Defaults to None.
|
558 |
+
|
559 |
+
-- note: if the image's height/width do not match the height/width of the generated image, the image is resized and center-cropped to match the height/width arguments.
|
560 |
+
|
561 |
+
-- If a string is provided, it is assumed to be either a path to an image file or a base64 encoded image.
|
562 |
+
|
563 |
+
-- If a numpy array is provided, it is assumed to be an RGB numpy array of shape (height, width, 3) and dtype uint8.
|
564 |
+
|
565 |
+
-- If a PIL.Image is provided, it is assumed to be an RGB PIL.Image.
|
566 |
+
|
567 |
+
-- If a torch.Tensor is provided, it is assumed to be a torch.Tensor of shape (height, width, 3) and dtype uint8 with range [0, 255].
|
568 |
+
|
569 |
+
strength `(float, optional)`: Strength of the init_image in image-to-image generation. Defaults to 1.0.
|
570 |
+
|
571 |
+
silent `(bool, optional)`: If True, suppresses progress bar. Defaults to False.
|
572 |
+
|
573 |
+
num_images `(int, optional)`: Number of images to generate. Defaults to 1.
|
574 |
+
|
575 |
+
return_seed `(bool, optional)`: If True, returns the seed along with the generated image. Defaults to False.
|
576 |
+
|
577 |
+
jpeg_quality `(int, optional)`: Quality of the JPEG compression. Defaults to 99.
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
io.BytesIO: Generated image(s) in bytes format.
|
581 |
+
int: Seed used for generation (only if return_seed is True).
|
582 |
+
"""
|
583 |
+
num_steps = 4 if self.name == "flux-schnell" else num_steps
|
584 |
+
|
585 |
+
init_image = self.load_init_image_if_needed(init_image)
|
586 |
+
|
587 |
+
# allow for packing and conversion to latent space
|
588 |
+
height = 16 * (height // 16)
|
589 |
+
width = 16 * (width // 16)
|
590 |
+
|
591 |
+
generator, seed = self.set_seed(seed)
|
592 |
+
|
593 |
+
if not silent:
|
594 |
+
logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
|
595 |
+
|
596 |
+
# preprocess the latent
|
597 |
+
img, timesteps = self.preprocess_latent(
|
598 |
+
init_image=init_image,
|
599 |
+
height=height,
|
600 |
+
width=width,
|
601 |
+
num_steps=num_steps,
|
602 |
+
strength=strength,
|
603 |
+
generator=generator,
|
604 |
+
num_images=num_images,
|
605 |
+
)
|
606 |
+
|
607 |
+
# prepare inputs
|
608 |
+
img, img_ids, vec, txt, txt_ids = map(
|
609 |
+
lambda x: x.contiguous(),
|
610 |
+
self.prepare(
|
611 |
+
img=img,
|
612 |
+
prompt=prompt,
|
613 |
+
target_device=self.device_flux,
|
614 |
+
target_dtype=self.dtype,
|
615 |
+
),
|
616 |
+
)
|
617 |
+
|
618 |
+
# this is ignored for schnell
|
619 |
+
guidance_vec = torch.full(
|
620 |
+
(img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
|
621 |
+
)
|
622 |
+
t_vec = None
|
623 |
+
# dispatch to gpu if offloaded
|
624 |
+
if self.offload_flow:
|
625 |
+
self.model.to(self.device_flux)
|
626 |
+
|
627 |
+
# perform the denoising loop
|
628 |
+
for t_curr, t_prev in tqdm(
|
629 |
+
zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
|
630 |
+
):
|
631 |
+
if t_vec is None:
|
632 |
+
t_vec = torch.full(
|
633 |
+
(img.shape[0],),
|
634 |
+
t_curr,
|
635 |
+
dtype=self.dtype,
|
636 |
+
device=self.device_flux,
|
637 |
+
)
|
638 |
+
else:
|
639 |
+
t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
|
640 |
+
|
641 |
+
pred = self.model.forward(
|
642 |
+
img=img,
|
643 |
+
img_ids=img_ids,
|
644 |
+
txt=txt,
|
645 |
+
txt_ids=txt_ids,
|
646 |
+
y=vec,
|
647 |
+
timesteps=t_vec,
|
648 |
+
guidance=guidance_vec,
|
649 |
+
)
|
650 |
+
|
651 |
+
img = img + (t_prev - t_curr) * pred
|
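            # One Euler step of the flow ODE; timesteps descend toward 0, so
            # (t_prev - t_curr) < 0 and img moves from noise toward the image.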
652 |
+
|
653 |
+
# offload the model to cpu if needed
|
654 |
+
if self.offload_flow:
|
655 |
+
self.model.to("cpu")
|
656 |
+
torch.cuda.empty_cache()
|
657 |
+
|
658 |
+
# decode latents to pixel space
|
659 |
+
img = self.vae_decode(img, height, width)
|
660 |
+
|
661 |
+
if return_seed:
|
662 |
+
return self.into_bytes(img, jpeg_quality=jpeg_quality), seed
|
663 |
+
return self.into_bytes(img, jpeg_quality=jpeg_quality)
|
664 |
+
|
665 |
+
@classmethod
|
666 |
+
def load_pipeline_from_config_path(
|
667 |
+
cls, path: str, flow_model_path: str | None = None, debug: bool = False, **kwargs
|
668 |
+
) -> "FluxPipeline":
|
669 |
+
with torch.inference_mode():
|
670 |
+
config = load_config_from_path(path)
|
671 |
+
if flow_model_path:
|
672 |
+
config.ckpt_path = flow_model_path
|
673 |
+
for k, v in kwargs.items():
|
674 |
+
if hasattr(config, k):
|
675 |
+
logger.info(
|
676 |
+
f"Overriding config {k}:{getattr(config, k)} with value {v}"
|
677 |
+
)
|
678 |
+
setattr(config, k, v)
|
679 |
+
return cls.load_pipeline_from_config(config, debug=debug)
|
680 |
+
|
681 |
+
@classmethod
|
682 |
+
def load_pipeline_from_config(
|
683 |
+
cls, config: ModelSpec, debug: bool = False
|
684 |
+
) -> "FluxPipeline":
|
685 |
+
from float8_quantize import quantize_flow_transformer_and_dispatch_float8
|
686 |
+
|
687 |
+
with torch.inference_mode():
|
688 |
+
if debug:
|
689 |
+
logger.info(
|
690 |
+
f"Loading as prequantized flow transformer? {config.prequantized_flow}"
|
691 |
+
)
|
692 |
+
|
693 |
+
models = load_models_from_config(config)
|
694 |
+
config = models.config
|
695 |
+
flux_device = into_device(config.flux_device)
|
696 |
+
ae_device = into_device(config.ae_device)
|
697 |
+
clip_device = into_device(config.text_enc_device)
|
698 |
+
t5_device = into_device(config.text_enc_device)
|
699 |
+
flux_dtype = into_dtype(config.flow_dtype)
|
700 |
+
flow_model = models.flow
|
701 |
+
|
702 |
+
if not config.prequantized_flow:
|
703 |
+
flow_model = quantize_flow_transformer_and_dispatch_float8(
|
704 |
+
flow_model,
|
705 |
+
flux_device,
|
706 |
+
offload_flow=config.offload_flow,
|
707 |
+
swap_linears_with_cublaslinear=flux_dtype == torch.float16,
|
708 |
+
flow_dtype=flux_dtype,
|
709 |
+
quantize_modulation=config.quantize_modulation,
|
710 |
+
quantize_flow_embedder_layers=config.quantize_flow_embedder_layers,
|
711 |
+
)
|
712 |
+
else:
|
713 |
+
flow_model.eval().requires_grad_(False)
|
714 |
+
|
715 |
+
return cls(
|
716 |
+
name=config.version,
|
717 |
+
clip=models.clip,
|
718 |
+
t5=models.t5,
|
719 |
+
model=flow_model,
|
720 |
+
ae=models.ae,
|
721 |
+
dtype=flux_dtype,
|
722 |
+
verbose=False,
|
723 |
+
flux_device=flux_device,
|
724 |
+
ae_device=ae_device,
|
725 |
+
clip_device=clip_device,
|
726 |
+
t5_device=t5_device,
|
727 |
+
config=config,
|
728 |
+
debug=debug,
|
729 |
+
)
|
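
For orientation, a minimal usage sketch of the pipeline defined above (the config path, prompt, and output name are illustrative; any of the bundled configs/*.json should work the same way):

from flux_pipeline import FluxPipeline

# Build a pipeline from one of the bundled config files.
pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev.json")

# generate() returns JPEG bytes in an io.BytesIO (plus the seed when return_seed=True).
jpeg, seed = pipe.generate(
    prompt="a misty forest at dawn",
    width=1024,
    height=720,
    num_steps=24,
    guidance=3.5,
    return_seed=True,
)
with open("out.jpg", "wb") as f:
    f.write(jpeg.getvalue())
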
image_encoder.py
ADDED
@@ -0,0 +1,35 @@
1 |
+
import io
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class ImageEncoder:
|
8 |
+
|
9 |
+
@torch.inference_mode()
|
10 |
+
def encode_torch(self, img: torch.Tensor, quality=95):
|
11 |
+
if img.ndim == 2:
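            # grayscale (H, W): replicate to 3 channels and convert to channels-last uint8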
|
12 |
+
img = (
|
13 |
+
img[None]
|
14 |
+
.repeat_interleave(3, dim=0)
|
15 |
+
.permute(1, 2, 0)
|
16 |
+
.contiguous()
|
17 |
+
.clamp(0, 255)
|
18 |
+
.type(torch.uint8)
|
19 |
+
)
|
20 |
+
elif img.ndim == 3:
|
21 |
+
if img.shape[0] == 3:
|
22 |
+
img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8)
|
23 |
+
elif img.shape[2] == 3:
|
24 |
+
img = img.contiguous().clamp(0, 255).type(torch.uint8)
|
25 |
+
else:
|
26 |
+
raise ValueError(f"Unsupported image shape: {img.shape}")
|
27 |
+
else:
|
28 |
+
raise ValueError(f"Unsupported image num dims: {img.ndim}")
|
29 |
+
|
30 |
+
img = img.cpu().numpy().astype(np.uint8)
|
31 |
+
im = Image.fromarray(img)
|
32 |
+
iob = io.BytesIO()
|
33 |
+
im.save(iob, format="JPEG", quality=quality)
|
34 |
+
iob.seek(0)
|
35 |
+
return iob
|
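
A quick sketch of driving ImageEncoder.encode_torch above (the random tensor is a stand-in for a decoded image; shapes and values are illustrative):

import torch
from image_encoder import ImageEncoder

encoder = ImageEncoder()
# channels-last RGB tensor, shape (H, W, 3), values in [0, 255]
img = torch.randint(0, 256, (256, 256, 3), dtype=torch.uint8)
iob = encoder.encode_torch(img, quality=95)  # io.BytesIO holding JPEG bytes
with open("preview.jpg", "wb") as f:
    f.write(iob.getvalue())
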
lora_loading.py
ADDED
@@ -0,0 +1,753 @@
1 |
+
import re
|
2 |
+
from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
|
3 |
+
import torch
|
4 |
+
from loguru import logger
|
5 |
+
from safetensors.torch import load_file
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
try:
|
10 |
+
from cublas_ops import CublasLinear
|
11 |
+
except Exception:
|
12 |
+
CublasLinear = type(None)
|
13 |
+
from float8_quantize import F8Linear
|
14 |
+
from modules.flux_model import Flux
|
15 |
+
|
16 |
+
path_regex = re.compile(r"/|\\")
|
17 |
+
|
18 |
+
StateDict: TypeAlias = OrderedDict[str, torch.Tensor]
|
19 |
+
|
20 |
+
|
21 |
+
class LoraWeights:
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
weights: StateDict,
|
25 |
+
path: str,
|
26 |
+
name: str = None,
|
27 |
+
scale: float = 1.0,
|
28 |
+
) -> None:
|
29 |
+
self.path = path
|
30 |
+
self.weights = weights
|
31 |
+
self.name = name if name else path_regex.split(path)[-1]
|
32 |
+
self.scale = scale
|
33 |
+
|
34 |
+
|
35 |
+
def swap_scale_shift(weight):
|
36 |
+
scale, shift = weight.chunk(2, dim=0)
|
37 |
+
new_weight = torch.cat([shift, scale], dim=0)
|
38 |
+
return new_weight
|
39 |
+
|
40 |
+
|
41 |
+
def check_if_lora_exists(state_dict, lora_name):
|
42 |
+
subkey = lora_name.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0]
|
43 |
+
for key in state_dict.keys():
|
44 |
+
if subkey in key:
|
45 |
+
return subkey
|
46 |
+
return False
|
47 |
+
|
48 |
+
|
49 |
+
def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name):
|
50 |
+
if (original_stubkey := check_if_lora_exists(state_dict, lora_name)) is not False:
|
51 |
+
weights_to_pop = [k for k in state_dict.keys() if original_stubkey in k]
|
52 |
+
for key in weights_to_pop:
|
53 |
+
key_replacement = key.replace(
|
54 |
+
original_stubkey, flux_layer_name.replace(".weight", "")
|
55 |
+
)
|
56 |
+
new_state_dict[key_replacement] = state_dict.pop(key)
|
57 |
+
return new_state_dict, state_dict
|
58 |
+
else:
|
59 |
+
return new_state_dict, state_dict
|
60 |
+
|
61 |
+
|
62 |
+
def convert_diffusers_to_flux_transformer_checkpoint(
|
63 |
+
diffusers_state_dict,
|
64 |
+
num_layers,
|
65 |
+
num_single_layers,
|
66 |
+
has_guidance=True,
|
67 |
+
prefix="",
|
68 |
+
):
|
69 |
+
original_state_dict = {}
|
70 |
+
|
71 |
+
# time_text_embed.timestep_embedder -> time_in
|
72 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
73 |
+
original_state_dict,
|
74 |
+
diffusers_state_dict,
|
75 |
+
f"{prefix}time_text_embed.timestep_embedder.linear_1.weight",
|
76 |
+
"time_in.in_layer.weight",
|
77 |
+
)
|
78 |
+
# time_text_embed.text_embedder -> vector_in
|
79 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
80 |
+
original_state_dict,
|
81 |
+
diffusers_state_dict,
|
82 |
+
f"{prefix}time_text_embed.text_embedder.linear_1.weight",
|
83 |
+
"vector_in.in_layer.weight",
|
84 |
+
)
|
85 |
+
|
86 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
87 |
+
original_state_dict,
|
88 |
+
diffusers_state_dict,
|
89 |
+
f"{prefix}time_text_embed.text_embedder.linear_2.weight",
|
90 |
+
"vector_in.out_layer.weight",
|
91 |
+
)
|
92 |
+
|
93 |
+
if has_guidance:
|
94 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
95 |
+
original_state_dict,
|
96 |
+
diffusers_state_dict,
|
97 |
+
f"{prefix}time_text_embed.guidance_embedder.linear_1.weight",
|
98 |
+
"guidance_in.in_layer.weight",
|
99 |
+
)
|
100 |
+
|
101 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
102 |
+
original_state_dict,
|
103 |
+
diffusers_state_dict,
|
104 |
+
f"{prefix}time_text_embed.guidance_embedder.linear_2.weight",
|
105 |
+
"guidance_in.out_layer.weight",
|
106 |
+
)
|
107 |
+
|
108 |
+
# context_embedder -> txt_in
|
109 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
110 |
+
original_state_dict,
|
111 |
+
diffusers_state_dict,
|
112 |
+
f"{prefix}context_embedder.weight",
|
113 |
+
"txt_in.weight",
|
114 |
+
)
|
115 |
+
|
116 |
+
# x_embedder -> img_in
|
117 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
118 |
+
original_state_dict,
|
119 |
+
diffusers_state_dict,
|
120 |
+
f"{prefix}x_embedder.weight",
|
121 |
+
"img_in.weight",
|
122 |
+
)
|
123 |
+
# double transformer blocks
|
124 |
+
for i in range(num_layers):
|
125 |
+
block_prefix = f"transformer_blocks.{i}."
|
126 |
+
# norms
|
127 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
128 |
+
original_state_dict,
|
129 |
+
diffusers_state_dict,
|
130 |
+
f"{prefix}{block_prefix}norm1.linear.weight",
|
131 |
+
f"double_blocks.{i}.img_mod.lin.weight",
|
132 |
+
)
|
133 |
+
|
134 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
135 |
+
original_state_dict,
|
136 |
+
diffusers_state_dict,
|
137 |
+
f"{prefix}{block_prefix}norm1_context.linear.weight",
|
138 |
+
f"double_blocks.{i}.txt_mod.lin.weight",
|
139 |
+
)
|
140 |
+
|
141 |
+
# Q, K, V
|
142 |
+
temp_dict = {}
|
143 |
+
|
144 |
+
expected_shape_qkv_a = None
|
145 |
+
expected_shape_qkv_b = None
|
146 |
+
expected_shape_add_qkv_a = None
|
147 |
+
expected_shape_add_qkv_b = None
|
148 |
+
dtype = None
|
149 |
+
device = None
|
150 |
+
|
151 |
+
for component in [
|
152 |
+
"to_q",
|
153 |
+
"to_k",
|
154 |
+
"to_v",
|
155 |
+
"add_q_proj",
|
156 |
+
"add_k_proj",
|
157 |
+
"add_v_proj",
|
158 |
+
]:
|
159 |
+
|
160 |
+
sample_component_A_key = (
|
161 |
+
f"{prefix}{block_prefix}attn.{component}.lora_A.weight"
|
162 |
+
)
|
163 |
+
sample_component_B_key = (
|
164 |
+
f"{prefix}{block_prefix}attn.{component}.lora_B.weight"
|
165 |
+
)
|
166 |
+
if (
|
167 |
+
sample_component_A_key in diffusers_state_dict
|
168 |
+
and sample_component_B_key in diffusers_state_dict
|
169 |
+
):
|
170 |
+
sample_component_A = diffusers_state_dict.pop(sample_component_A_key)
|
171 |
+
sample_component_B = diffusers_state_dict.pop(sample_component_B_key)
|
172 |
+
temp_dict[f"{component}"] = [sample_component_A, sample_component_B]
|
173 |
+
if expected_shape_qkv_a is None and not component.startswith("add_"):
|
174 |
+
expected_shape_qkv_a = sample_component_A.shape
|
175 |
+
expected_shape_qkv_b = sample_component_B.shape
|
176 |
+
dtype = sample_component_A.dtype
|
177 |
+
device = sample_component_A.device
|
178 |
+
if expected_shape_add_qkv_a is None and component.startswith("add_"):
|
179 |
+
expected_shape_add_qkv_a = sample_component_A.shape
|
180 |
+
expected_shape_add_qkv_b = sample_component_B.shape
|
181 |
+
dtype = sample_component_A.dtype
|
182 |
+
device = sample_component_A.device
|
183 |
+
else:
|
184 |
+
logger.info(
|
185 |
+
f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
|
186 |
+
)
|
187 |
+
temp_dict[f"{component}"] = [None, None]
|
188 |
+
|
189 |
+
if device is not None:
|
190 |
+
if expected_shape_qkv_a is not None:
|
191 |
+
|
192 |
+
if (sq := temp_dict["to_q"])[0] is not None:
|
193 |
+
sample_q_A, sample_q_B = sq
|
194 |
+
else:
|
195 |
+
sample_q_A, sample_q_B = [
|
196 |
+
torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
|
197 |
+
torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
|
198 |
+
]
|
199 |
+
if (sq := temp_dict["to_k"])[0] is not None:
|
200 |
+
sample_k_A, sample_k_B = sq
|
201 |
+
else:
|
202 |
+
sample_k_A, sample_k_B = [
|
203 |
+
torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
|
204 |
+
torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
|
205 |
+
]
|
206 |
+
if (sq := temp_dict["to_v"])[0] is not None:
|
207 |
+
sample_v_A, sample_v_B = sq
|
208 |
+
else:
|
209 |
+
sample_v_A, sample_v_B = [
|
210 |
+
torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
|
211 |
+
torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
|
212 |
+
]
|
213 |
+
original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
|
214 |
+
torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
|
215 |
+
)
|
216 |
+
original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
|
217 |
+
torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
|
218 |
+
)
|
219 |
+
if expected_shape_add_qkv_a is not None:
|
220 |
+
|
221 |
+
if (sq := temp_dict["add_q_proj"])[0] is not None:
|
222 |
+
context_q_A, context_q_B = sq
|
223 |
+
else:
|
224 |
+
context_q_A, context_q_B = [
|
225 |
+
torch.zeros(
|
226 |
+
expected_shape_add_qkv_a, dtype=dtype, device=device
|
227 |
+
),
|
228 |
+
torch.zeros(
|
229 |
+
expected_shape_add_qkv_b, dtype=dtype, device=device
|
230 |
+
),
|
231 |
+
]
|
232 |
+
if (sq := temp_dict["add_k_proj"])[0] is not None:
|
233 |
+
context_k_A, context_k_B = sq
|
234 |
+
else:
|
235 |
+
context_k_A, context_k_B = [
|
236 |
+
torch.zeros(
|
237 |
+
expected_shape_add_qkv_a, dtype=dtype, device=device
|
238 |
+
),
|
239 |
+
torch.zeros(
|
240 |
+
expected_shape_add_qkv_b, dtype=dtype, device=device
|
241 |
+
),
|
242 |
+
]
|
243 |
+
if (sq := temp_dict["add_v_proj"])[0] is not None:
|
244 |
+
context_v_A, context_v_B = sq
|
245 |
+
else:
|
246 |
+
context_v_A, context_v_B = [
|
247 |
+
torch.zeros(
|
248 |
+
expected_shape_add_qkv_a, dtype=dtype, device=device
|
249 |
+
),
|
250 |
+
torch.zeros(
|
251 |
+
expected_shape_add_qkv_b, dtype=dtype, device=device
|
252 |
+
),
|
253 |
+
]
|
254 |
+
|
255 |
+
original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
|
256 |
+
torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
|
257 |
+
)
|
258 |
+
original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
|
259 |
+
torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
|
260 |
+
)
|
261 |
+
|
262 |
+
# qk_norm
|
263 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
264 |
+
original_state_dict,
|
265 |
+
diffusers_state_dict,
|
266 |
+
f"{prefix}{block_prefix}attn.norm_q.weight",
|
267 |
+
f"double_blocks.{i}.img_attn.norm.query_norm.scale",
|
268 |
+
)
|
269 |
+
|
270 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
271 |
+
original_state_dict,
|
272 |
+
diffusers_state_dict,
|
273 |
+
f"{prefix}{block_prefix}attn.norm_k.weight",
|
274 |
+
f"double_blocks.{i}.img_attn.norm.key_norm.scale",
|
275 |
+
)
|
276 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
277 |
+
original_state_dict,
|
278 |
+
diffusers_state_dict,
|
279 |
+
f"{prefix}{block_prefix}attn.norm_added_q.weight",
|
280 |
+
f"double_blocks.{i}.txt_attn.norm.query_norm.scale",
|
281 |
+
)
|
282 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
283 |
+
original_state_dict,
|
284 |
+
diffusers_state_dict,
|
285 |
+
f"{prefix}{block_prefix}attn.norm_added_k.weight",
|
286 |
+
f"double_blocks.{i}.txt_attn.norm.key_norm.scale",
|
287 |
+
)
|
288 |
+
|
289 |
+
# ff img_mlp
|
290 |
+
|
291 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
292 |
+
original_state_dict,
|
293 |
+
diffusers_state_dict,
|
294 |
+
f"{prefix}{block_prefix}ff.net.0.proj.weight",
|
295 |
+
f"double_blocks.{i}.img_mlp.0.weight",
|
296 |
+
)
|
297 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
298 |
+
original_state_dict,
|
299 |
+
diffusers_state_dict,
|
300 |
+
f"{prefix}{block_prefix}ff.net.2.weight",
|
301 |
+
f"double_blocks.{i}.img_mlp.2.weight",
|
302 |
+
)
|
303 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
304 |
+
original_state_dict,
|
305 |
+
diffusers_state_dict,
|
306 |
+
f"{prefix}{block_prefix}ff_context.net.0.proj.weight",
|
307 |
+
f"double_blocks.{i}.txt_mlp.0.weight",
|
308 |
+
)
|
309 |
+
|
310 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
311 |
+
original_state_dict,
|
312 |
+
diffusers_state_dict,
|
313 |
+
f"{prefix}{block_prefix}ff_context.net.2.weight",
|
314 |
+
f"double_blocks.{i}.txt_mlp.2.weight",
|
315 |
+
)
|
316 |
+
# output projections
|
317 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
318 |
+
original_state_dict,
|
319 |
+
diffusers_state_dict,
|
320 |
+
f"{prefix}{block_prefix}attn.to_out.0.weight",
|
321 |
+
f"double_blocks.{i}.img_attn.proj.weight",
|
322 |
+
)
|
323 |
+
|
324 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
325 |
+
original_state_dict,
|
326 |
+
diffusers_state_dict,
|
327 |
+
f"{prefix}{block_prefix}attn.to_add_out.weight",
|
328 |
+
f"double_blocks.{i}.txt_attn.proj.weight",
|
329 |
+
)
|
330 |
+
|
331 |
+
# single transformer blocks
|
332 |
+
for i in range(num_single_layers):
|
333 |
+
block_prefix = f"single_transformer_blocks.{i}."
|
334 |
+
# norm.linear -> single_blocks.0.modulation.lin
|
335 |
+
key_norm = f"{prefix}{block_prefix}norm.linear.weight"
|
336 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
337 |
+
original_state_dict,
|
338 |
+
diffusers_state_dict,
|
339 |
+
key_norm,
|
340 |
+
f"single_blocks.{i}.modulation.lin.weight",
|
341 |
+
)
|
342 |
+
|
343 |
+
has_q, has_k, has_v, has_mlp = False, False, False, False
|
344 |
+
shape_qkv_a = None
|
345 |
+
shape_qkv_b = None
|
346 |
+
# Q, K, V, mlp
|
347 |
+
q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight", None)
|
348 |
+
q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight", None)
|
349 |
+
if q_A is not None and q_B is not None:
|
350 |
+
has_q = True
|
351 |
+
shape_qkv_a = q_A.shape
|
352 |
+
shape_qkv_b = q_B.shape
|
353 |
+
k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight", None)
|
354 |
+
k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight", None)
|
355 |
+
if k_A is not None and k_B is not None:
|
356 |
+
has_k = True
|
357 |
+
shape_qkv_a = k_A.shape
|
358 |
+
shape_qkv_b = k_B.shape
|
359 |
+
v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight", None)
|
360 |
+
v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight", None)
|
361 |
+
if v_A is not None and v_B is not None:
|
362 |
+
has_v = True
|
363 |
+
shape_qkv_a = v_A.shape
|
364 |
+
shape_qkv_b = v_B.shape
|
365 |
+
mlp_A = diffusers_state_dict.pop(
|
366 |
+
f"{prefix}{block_prefix}proj_mlp.lora_A.weight"
|
367 |
+
)
|
368 |
+
mlp_B = diffusers_state_dict.pop(
|
369 |
+
f"{prefix}{block_prefix}proj_mlp.lora_B.weight"
|
370 |
+
)
|
371 |
+
if mlp_A is not None and mlp_B is not None:
|
372 |
+
has_mlp = True
|
373 |
+
shape_qkv_a = mlp_A.shape
|
374 |
+
shape_qkv_b = mlp_B.shape
|
375 |
+
if any([has_q, has_k, has_v, has_mlp]):
|
376 |
+
if not has_q:
|
377 |
+
q_A, q_B = [
|
378 |
+
torch.zeros(shape_qkv_a, dtype=dtype, device=device),
|
379 |
+
torch.zeros(shape_qkv_b, dtype=dtype, device=device),
|
380 |
+
]
|
381 |
+
if not has_k:
|
382 |
+
k_A, k_B = [
|
383 |
+
torch.zeros(shape_qkv_a, dtype=dtype, device=device),
|
384 |
+
torch.zeros(shape_qkv_b, dtype=dtype, device=device),
|
385 |
+
]
|
386 |
+
if not has_v:
|
387 |
+
v_A, v_B = [
|
388 |
+
torch.zeros(shape_qkv_a, dtype=dtype, device=device),
|
389 |
+
torch.zeros(shape_qkv_b, dtype=dtype, device=device),
|
390 |
+
]
|
391 |
+
if not has_mlp:
|
392 |
+
mlp_A, mlp_B = [
|
393 |
+
torch.zeros(shape_qkv_a, dtype=dtype, device=device),
|
394 |
+
torch.zeros(shape_qkv_b, dtype=dtype, device=device),
|
395 |
+
]
|
396 |
+
original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
|
397 |
+
[q_A, k_A, v_A, mlp_A], dim=0
|
398 |
+
)
|
399 |
+
original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
|
400 |
+
[q_B, k_B, v_B, mlp_B], dim=0
|
401 |
+
)
|
402 |
+
|
403 |
+
# output projections
|
404 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
405 |
+
original_state_dict,
|
406 |
+
diffusers_state_dict,
|
407 |
+
f"{prefix}{block_prefix}proj_out.weight",
|
408 |
+
f"single_blocks.{i}.linear2.weight",
|
409 |
+
)
|
410 |
+
|
411 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
412 |
+
original_state_dict,
|
413 |
+
diffusers_state_dict,
|
414 |
+
f"{prefix}proj_out.weight",
|
415 |
+
"final_layer.linear.weight",
|
416 |
+
)
|
417 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
418 |
+
original_state_dict,
|
419 |
+
diffusers_state_dict,
|
420 |
+
f"{prefix}proj_out.bias",
|
421 |
+
"final_layer.linear.bias",
|
422 |
+
)
|
423 |
+
original_state_dict, diffusers_state_dict = convert_if_lora_exists(
|
424 |
+
original_state_dict,
|
425 |
+
diffusers_state_dict,
|
426 |
+
f"{prefix}norm_out.linear.weight",
|
427 |
+
"final_layer.adaLN_modulation.1.weight",
|
428 |
+
)
|
429 |
+
if len(diffusers_state_dict) > 0:
|
430 |
+
logger.warning("Unexpected keys:", diffusers_state_dict.keys())
|
431 |
+
|
432 |
+
return original_state_dict
|
433 |
+
|
434 |
+
|
435 |
+
def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict:
|
436 |
+
"""
|
437 |
+
Convert the state dict from the original Flux checkpoint format to the new format.
|
438 |
+
|
439 |
+
Args:
|
440 |
+
original_state_dict (Dict[str, torch.Tensor]): The original Flux checkpoint state dict.
|
441 |
+
|
442 |
+
Returns:
|
443 |
+
Dict[str, torch.Tensor]: The converted state dict in the new format.
|
444 |
+
"""
|
445 |
+
sd = {
|
446 |
+
k.replace("lora_unet_", "")
|
447 |
+
.replace("double_blocks_", "double_blocks.")
|
448 |
+
.replace("single_blocks_", "single_blocks.")
|
449 |
+
.replace("_img_attn_", ".img_attn.")
|
450 |
+
.replace("_txt_attn_", ".txt_attn.")
|
451 |
+
.replace("_img_mod_", ".img_mod.")
|
452 |
+
.replace("_txt_mod_", ".txt_mod.")
|
453 |
+
.replace("_img_mlp_", ".img_mlp.")
|
454 |
+
.replace("_txt_mlp_", ".txt_mlp.")
|
455 |
+
.replace("_linear1", ".linear1")
|
456 |
+
.replace("_linear2", ".linear2")
|
457 |
+
.replace("_modulation_", ".modulation.")
|
458 |
+
.replace("lora_up", "lora_B")
|
459 |
+
.replace("lora_down", "lora_A"): v
|
460 |
+
for k, v in original_state_dict.items()
|
461 |
+
if "lora" in k
|
462 |
+
}
|
463 |
+
return sd
|
464 |
+
|
465 |
+
|
466 |
+
def get_module_for_key(
|
467 |
+
key: str, model: Flux
|
468 |
+
) -> F8Linear | torch.nn.Linear | CublasLinear:
|
469 |
+
parts = key.split(".")
|
470 |
+
module = model
|
471 |
+
for part in parts:
|
472 |
+
module = getattr(module, part)
|
473 |
+
return module
|
474 |
+
|
475 |
+
|
476 |
+
def get_lora_for_key(
|
477 |
+
key: str, lora_weights: dict
|
478 |
+
) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]:
|
479 |
+
"""
|
480 |
+
Get LoRA weights for a specific key.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
key (str): The key to look up in the LoRA weights.
|
484 |
+
lora_weights (dict): Dictionary containing LoRA weights.
|
485 |
+
|
486 |
+
Returns:
|
487 |
+
Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: A tuple containing lora_A, lora_B, and alpha if found, None otherwise.
|
488 |
+
"""
|
489 |
+
prefix = key.split(".lora")[0]
|
490 |
+
lora_A = lora_weights.get(f"{prefix}.lora_A.weight")
|
491 |
+
lora_B = lora_weights.get(f"{prefix}.lora_B.weight")
|
492 |
+
alpha = lora_weights.get(f"{prefix}.alpha")
|
493 |
+
|
494 |
+
if lora_A is None or lora_B is None:
|
495 |
+
return None
|
496 |
+
return lora_A, lora_B, alpha
|
497 |
+
|
498 |
+
|
507 |
+
|
508 |
+
|
509 |
+
def calculate_lora_weight(
|
510 |
+
lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
|
511 |
+
rank: Optional[int] = None,
|
512 |
+
lora_scale: float = 1.0,
|
513 |
+
device: Optional[Union[torch.device, int, str]] = None,
|
514 |
+
):
|
515 |
+
lora_A, lora_B, alpha = lora_weights
|
516 |
+
if device is None:
|
517 |
+
device = lora_A.device
|
518 |
+
|
519 |
+
uneven_rank = lora_B.shape[1] != lora_A.shape[0]
|
520 |
+
rank_diff = lora_A.shape[0] / lora_B.shape[1]
|
521 |
+
|
522 |
+
if rank is None:
|
523 |
+
rank = lora_B.shape[1]
|
524 |
+
if alpha is None:
|
525 |
+
alpha = rank
|
526 |
+
|
527 |
+
dtype = torch.float32
|
528 |
+
w_up = lora_A.to(dtype=dtype, device=device)
|
529 |
+
w_down = lora_B.to(dtype=dtype, device=device)
|
530 |
+
|
531 |
+
if alpha != rank:
|
532 |
+
w_up = w_up * alpha / rank
|
533 |
+
if uneven_rank:
|
534 |
+
# Fuse each lora instead of repeat interleave for each individual lora,
|
535 |
+
# seems to fuse more correctly.
|
536 |
+
fused_lora = torch.zeros(
|
537 |
+
(lora_B.shape[0], lora_A.shape[1]), device=device, dtype=dtype
|
538 |
+
)
|
539 |
+
w_up = w_up.chunk(int(rank_diff), dim=0)
|
540 |
+
for w_up_chunk in w_up:
|
541 |
+
fused_lora = fused_lora + (lora_scale * torch.mm(w_down, w_up_chunk))
|
542 |
+
else:
|
543 |
+
fused_lora = lora_scale * torch.mm(w_down, w_up)
|
544 |
+
return fused_lora
|
545 |
+
|
546 |
+
|
547 |
+
@torch.inference_mode()
|
548 |
+
def unfuse_lora_weight_from_module(
|
549 |
+
fused_weight: torch.Tensor,
|
550 |
+
lora_weights: dict,
|
551 |
+
rank: Optional[int] = None,
|
552 |
+
lora_scale: float = 1.0,
|
553 |
+
):
|
554 |
+
w_dtype = fused_weight.dtype
|
555 |
+
dtype = torch.float32
|
556 |
+
device = fused_weight.device
|
557 |
+
|
558 |
+
fused_weight = fused_weight.to(dtype=dtype, device=device)
|
559 |
+
fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
|
560 |
+
module_weight = fused_weight - fused_lora
|
561 |
+
return module_weight.to(dtype=w_dtype, device=device)
|
562 |
+
|
563 |
+
|
564 |
+
@torch.inference_mode()
|
565 |
+
def apply_lora_weight_to_module(
|
566 |
+
module_weight: torch.Tensor,
|
567 |
+
lora_weights: dict,
|
568 |
+
rank: Optional[int] = None,
|
569 |
+
lora_scale: float = 1.0,
|
570 |
+
):
|
571 |
+
w_dtype = module_weight.dtype
|
572 |
+
dtype = torch.float32
|
573 |
+
device = module_weight.device
|
574 |
+
|
575 |
+
fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
|
576 |
+
fused_weight = module_weight.to(dtype=dtype) + fused_lora
|
577 |
+
return fused_weight.to(dtype=w_dtype, device=device)
|
578 |
+
|
579 |
+
|
580 |
+
def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
|
581 |
+
check_if_starts_with_transformer = [
|
582 |
+
k for k in lora_weights.keys() if k.startswith("transformer.")
|
583 |
+
]
|
584 |
+
if len(check_if_starts_with_transformer) > 0:
|
585 |
+
lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
|
586 |
+
lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
|
587 |
+
)
|
588 |
+
else:
|
589 |
+
lora_weights = convert_from_original_flux_checkpoint(lora_weights)
|
590 |
+
logger.info("LoRA weights loaded")
|
591 |
+
logger.debug("Extracting keys")
|
592 |
+
keys_without_ab = list(
|
593 |
+
set(
|
594 |
+
[
|
595 |
+
key.replace(".lora_A.weight", "")
|
596 |
+
.replace(".lora_B.weight", "")
|
597 |
+
.replace(".lora_A", "")
|
598 |
+
.replace(".lora_B", "")
|
599 |
+
.replace(".alpha", "")
|
600 |
+
for key in lora_weights.keys()
|
601 |
+
]
|
602 |
+
)
|
603 |
+
)
|
604 |
+
logger.debug("Keys extracted")
|
605 |
+
return keys_without_ab, lora_weights
|
606 |
+
|
607 |
+
|
608 |
+
def get_lora_weights(lora_path: str | StateDict | LoraWeights):
|
609 |
+
if isinstance(lora_path, (dict, LoraWeights)):
|
610 |
+
return lora_path, True
|
611 |
+
else:
|
612 |
+
return load_file(lora_path, "cpu"), False
|
613 |
+
|
614 |
+
|
615 |
+
def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
|
616 |
+
dtype = linear.weight.dtype
|
617 |
+
weight_is_f8 = False
|
618 |
+
if isinstance(linear, F8Linear):
|
619 |
+
weight_is_f8 = True
|
620 |
+
weight = (
|
621 |
+
linear.float8_data.clone()
|
622 |
+
.detach()
|
623 |
+
.float()
|
624 |
+
.mul(linear.scale_reciprocal)
|
625 |
+
.to(linear.weight.device)
|
626 |
+
)
|
627 |
+
elif isinstance(linear, torch.nn.Linear):
|
628 |
+
weight = linear.weight.clone().detach().float()
|
629 |
+
else:  # CublasLinear (when available) or any other linear-like module with a .weight
|
630 |
+
weight = linear.weight.clone().detach().float()
|
631 |
+
return weight, weight_is_f8, dtype
|
632 |
+
|
633 |
+
|
634 |
+
@torch.inference_mode()
|
635 |
+
def apply_lora_to_model(
|
636 |
+
model: Flux,
|
637 |
+
lora_path: str | StateDict,
|
638 |
+
lora_scale: float = 1.0,
|
639 |
+
return_lora_resolved: bool = False,
|
640 |
+
) -> Flux:
|
641 |
+
has_guidance = model.params.guidance_embed
|
642 |
+
logger.info(f"Loading LoRA weights for {lora_path}")
|
643 |
+
lora_weights, already_loaded = get_lora_weights(lora_path)
|
644 |
+
|
645 |
+
if not already_loaded:
|
646 |
+
keys_without_ab, lora_weights = resolve_lora_state_dict(
|
647 |
+
lora_weights, has_guidance
|
648 |
+
)
|
649 |
+
elif isinstance(lora_weights, LoraWeights):
|
650 |
+
b_ = lora_weights
|
651 |
+
lora_weights = b_.weights
|
652 |
+
keys_without_ab = list(
|
653 |
+
set(
|
654 |
+
[
|
655 |
+
key.replace(".lora_A.weight", "")
|
656 |
+
.replace(".lora_B.weight", "")
|
657 |
+
.replace(".lora_A", "")
|
658 |
+
.replace(".lora_B", "")
|
659 |
+
.replace(".alpha", "")
|
660 |
+
for key in lora_weights.keys()
|
661 |
+
]
|
662 |
+
)
|
663 |
+
)
|
664 |
+
else:
|
665 |
+
lora_weights = lora_weights
|
666 |
+
keys_without_ab = list(
|
667 |
+
set(
|
668 |
+
[
|
669 |
+
key.replace(".lora_A.weight", "")
|
670 |
+
.replace(".lora_B.weight", "")
|
671 |
+
.replace(".lora_A", "")
|
672 |
+
.replace(".lora_B", "")
|
673 |
+
.replace(".alpha", "")
|
674 |
+
for key in lora_weights.keys()
|
675 |
+
]
|
676 |
+
)
|
677 |
+
)
|
678 |
+
for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
|
679 |
+
module = get_module_for_key(key, model)
|
680 |
+
weight, is_f8, dtype = extract_weight_from_linear(module)
|
681 |
+
lora_sd = get_lora_for_key(key, lora_weights)
|
682 |
+
if lora_sd is None:
|
683 |
+
# Skipping LoRA application for this module
|
684 |
+
continue
|
685 |
+
weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
|
686 |
+
if is_f8:
|
687 |
+
module.set_weight_tensor(weight.type(dtype))
|
688 |
+
else:
|
689 |
+
module.weight.data = weight.type(dtype)
|
690 |
+
logger.success("Lora applied")
|
691 |
+
if return_lora_resolved:
|
692 |
+
return model, lora_weights
|
693 |
+
return model
|
694 |
+
|
695 |
+
|
696 |
+
def remove_lora_from_module(
|
697 |
+
model: Flux,
|
698 |
+
lora_path: str | StateDict,
|
699 |
+
lora_scale: float = 1.0,
|
700 |
+
):
|
701 |
+
has_guidance = model.params.guidance_embed
|
702 |
+
logger.info(f"Loading LoRA weights for {lora_path}")
|
703 |
+
lora_weights, already_loaded = get_lora_weights(lora_path)
|
704 |
+
|
705 |
+
if not already_loaded:
|
706 |
+
keys_without_ab, lora_weights = resolve_lora_state_dict(
|
707 |
+
lora_weights, has_guidance
|
708 |
+
)
|
709 |
+
elif isinstance(lora_weights, LoraWeights):
|
710 |
+
b_ = lora_weights
|
711 |
+
lora_weights = b_.weights
|
712 |
+
keys_without_ab = list(
|
713 |
+
set(
|
714 |
+
[
|
715 |
+
key.replace(".lora_A.weight", "")
|
716 |
+
.replace(".lora_B.weight", "")
|
717 |
+
.replace(".lora_A", "")
|
718 |
+
.replace(".lora_B", "")
|
719 |
+
.replace(".alpha", "")
|
720 |
+
for key in lora_weights.keys()
|
721 |
+
]
|
722 |
+
)
|
723 |
+
)
|
724 |
+
lora_scale = b_.scale
|
725 |
+
else:
|
726 |
+
# already a plain state dict; nothing to unwrap
|
727 |
+
keys_without_ab = list(
|
728 |
+
set(
|
729 |
+
[
|
730 |
+
key.replace(".lora_A.weight", "")
|
731 |
+
.replace(".lora_B.weight", "")
|
732 |
+
.replace(".lora_A", "")
|
733 |
+
.replace(".lora_B", "")
|
734 |
+
.replace(".alpha", "")
|
735 |
+
for key in lora_weights.keys()
|
736 |
+
]
|
737 |
+
)
|
738 |
+
)
|
739 |
+
|
740 |
+
for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
|
741 |
+
module = get_module_for_key(key, model)
|
742 |
+
weight, is_f8, dtype = extract_weight_from_linear(module)
|
743 |
+
lora_sd = get_lora_for_key(key, lora_weights)
|
744 |
+
if lora_sd is None:
|
745 |
+
# Skipping LoRA application for this module
|
746 |
+
continue
|
747 |
+
weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
|
748 |
+
if is_f8:
|
749 |
+
module.set_weight_tensor(weight.type(dtype))
|
750 |
+
else:
|
751 |
+
module.weight.data = weight.type(dtype)
|
752 |
+
logger.success("Lora unfused")
|
753 |
+
return model
|
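
For matched ranks, the fusing above reduces to W' = W + lora_scale * (alpha / rank) * (lora_B @ lora_A). A small numeric check of calculate_lora_weight under that assumption (all shapes are made up):

import torch
from lora_loading import calculate_lora_weight

rank, d_in, d_out = 16, 64, 64
lora_A = torch.randn(rank, d_in)   # "down" projection: (rank, in_features)
lora_B = torch.randn(d_out, rank)  # "up" projection: (out_features, rank)
alpha = 8.0

delta = calculate_lora_weight((lora_A, lora_B, alpha), lora_scale=0.8)
expected = 0.8 * (alpha / rank) * (lora_B @ lora_A)
assert torch.allclose(delta, expected, atol=1e-5)
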
main.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
import argparse
|
2 |
+
import uvicorn
|
3 |
+
from api import app
|
4 |
+
|
5 |
+
|
6 |
+
def parse_args():
|
7 |
+
parser = argparse.ArgumentParser(description="Launch Flux API server")
|
8 |
+
parser.add_argument(
|
9 |
+
"-c",
|
10 |
+
"--config-path",
|
11 |
+
type=str,
|
12 |
+
help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments",
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"-p",
|
16 |
+
"--port",
|
17 |
+
type=int,
|
18 |
+
default=8088,
|
19 |
+
help="Port to run the server on",
|
20 |
+
)
|
21 |
+
parser.add_argument(
|
22 |
+
"-H",
|
23 |
+
"--host",
|
24 |
+
type=str,
|
25 |
+
default="0.0.0.0",
|
26 |
+
help="Host to run the server on",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"-f", "--flow-model-path", type=str, help="Path to the flow model"
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"-t", "--text-enc-path", type=str, help="Path to the text encoder"
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"-m",
|
39 |
+
"--model-version",
|
40 |
+
type=str,
|
41 |
+
choices=["flux-dev", "flux-schnell"],
|
42 |
+
default="flux-dev",
|
43 |
+
help="Choose model version",
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"-F",
|
47 |
+
"--flux-device",
|
48 |
+
type=str,
|
49 |
+
default="cuda:0",
|
50 |
+
help="Device to run the flow model on",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"-T",
|
54 |
+
"--text-enc-device",
|
55 |
+
type=str,
|
56 |
+
default="cuda:0",
|
57 |
+
help="Device to run the text encoder on",
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"-A",
|
61 |
+
"--autoencoder-device",
|
62 |
+
type=str,
|
63 |
+
default="cuda:0",
|
64 |
+
help="Device to run the autoencoder on",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"-q",
|
68 |
+
"--num-to-quant",
|
69 |
+
type=int,
|
70 |
+
default=20,
|
71 |
+
help="Number of linear layers in flow transformer (the 'unet') to quantize",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"-C",
|
75 |
+
"--compile",
|
76 |
+
action="store_true",
|
77 |
+
default=False,
|
78 |
+
help="Compile the flow model with extra optimizations",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"-qT",
|
82 |
+
"--quant-text-enc",
|
83 |
+
type=str,
|
84 |
+
default="qfloat8",
|
85 |
+
choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"],
|
86 |
+
help="Quantize the t5 text encoder to the given dtype, if bf16, will not quantize",
|
87 |
+
dest="quant_text_enc",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"-qA",
|
91 |
+
"--quant-ae",
|
92 |
+
action="store_true",
|
93 |
+
default=False,
|
94 |
+
help="Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16",
|
95 |
+
dest="quant_ae",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"-OF",
|
99 |
+
"--offload-flow",
|
100 |
+
action="store_true",
|
101 |
+
default=False,
|
102 |
+
dest="offload_flow",
|
103 |
+
help="Offload the flow model to the CPU when not being used to save memory",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"-OA",
|
107 |
+
"--no-offload-ae",
|
108 |
+
action="store_false",
|
109 |
+
default=True,
|
110 |
+
dest="offload_ae",
|
111 |
+
help="Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"-OT",
|
115 |
+
"--no-offload-text-enc",
|
116 |
+
action="store_false",
|
117 |
+
default=True,
|
118 |
+
dest="offload_text_enc",
|
119 |
+
help="Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"-PF",
|
123 |
+
"--prequantized-flow",
|
124 |
+
action="store_true",
|
125 |
+
default=False,
|
126 |
+
dest="prequantized_flow",
|
127 |
+
help="Load the flow model from a prequantized checkpoint "
|
128 |
+
+ "(requires loading the flow model, running a minimum of 24 steps, "
|
129 |
+
+ "and then saving the state_dict as a safetensors file), "
|
130 |
+
+ "which reduces the size of the checkpoint by about 50% & reduces startup time",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"-nqfm",
|
134 |
+
"--no-quantize-flow-modulation",
|
135 |
+
action="store_false",
|
136 |
+
default=True,
|
137 |
+
dest="quantize_modulation",
|
138 |
+
help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"-qfl",
|
142 |
+
"--quantize-flow-embedder-layers",
|
143 |
+
action="store_true",
|
144 |
+
default=False,
|
145 |
+
dest="quantize_flow_embedder_layers",
|
146 |
+
help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
|
147 |
+
)
|
148 |
+
return parser.parse_args()
|
149 |
+
|
150 |
+
|
151 |
+
def main():
|
152 |
+
args = parse_args()
|
153 |
+
|
154 |
+
# lazy imports so the CLI responds quickly instead of waiting for torch to load
|
155 |
+
from flux_pipeline import FluxPipeline
|
156 |
+
from util import load_config, ModelVersion
|
157 |
+
|
158 |
+
if args.config_path:
|
159 |
+
app.state.model = FluxPipeline.load_pipeline_from_config_path(
|
160 |
+
args.config_path, flow_model_path=args.flow_model_path
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
model_version = (
|
164 |
+
ModelVersion.flux_dev
|
165 |
+
if args.model_version == "flux-dev"
|
166 |
+
else ModelVersion.flux_schnell
|
167 |
+
)
|
168 |
+
config = load_config(
|
169 |
+
model_version,
|
170 |
+
flux_path=args.flow_model_path,
|
171 |
+
flux_device=args.flux_device,
|
172 |
+
ae_path=args.autoencoder_path,
|
173 |
+
ae_device=args.autoencoder_device,
|
174 |
+
text_enc_path=args.text_enc_path,
|
175 |
+
text_enc_device=args.text_enc_device,
|
176 |
+
flow_dtype="float16",
|
177 |
+
text_enc_dtype="bfloat16",
|
178 |
+
ae_dtype="bfloat16",
|
179 |
+
num_to_quant=args.num_to_quant,
|
180 |
+
compile_extras=args.compile,
|
181 |
+
compile_blocks=args.compile,
|
182 |
+
quant_text_enc=(
|
183 |
+
None if args.quant_text_enc == "bf16" else args.quant_text_enc
|
184 |
+
),
|
185 |
+
quant_ae=args.quant_ae,
|
186 |
+
offload_flow=args.offload_flow,
|
187 |
+
offload_ae=args.offload_ae,
|
188 |
+
offload_text_enc=args.offload_text_enc,
|
189 |
+
prequantized_flow=args.prequantized_flow,
|
190 |
+
quantize_modulation=args.quantize_modulation,
|
191 |
+
quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
|
192 |
+
)
|
193 |
+
app.state.model = FluxPipeline.load_pipeline_from_config(config)
|
194 |
+
|
195 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
196 |
+
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
main()
|
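
Illustrative invocations of this launcher (model paths are placeholders):

python main.py --config-path configs/config-dev.json --port 8088

# or configure everything from flags instead of a config file:
python main.py -m flux-dev -f /path/to/flux1-dev.safetensors \
    -a /path/to/ae.safetensors -t /path/to/text-encoder -OF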
modules/autoencoder.py
ADDED
@@ -0,0 +1,336 @@
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import Tensor, nn
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
|
7 |
+
class AutoEncoderParams(BaseModel):
|
8 |
+
resolution: int
|
9 |
+
in_channels: int
|
10 |
+
ch: int
|
11 |
+
out_ch: int
|
12 |
+
ch_mult: list[int]
|
13 |
+
num_res_blocks: int
|
14 |
+
z_channels: int
|
15 |
+
scale_factor: float
|
16 |
+
shift_factor: float
|
17 |
+
|
18 |
+
|
19 |
+
def swish(x: Tensor) -> Tensor:
|
20 |
+
return x * torch.sigmoid(x)
|
21 |
+
|
22 |
+
|
23 |
+
class AttnBlock(nn.Module):
|
24 |
+
def __init__(self, in_channels: int):
|
25 |
+
super().__init__()
|
26 |
+
self.in_channels = in_channels
|
27 |
+
|
28 |
+
self.norm = nn.GroupNorm(
|
29 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
30 |
+
)
|
31 |
+
|
32 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
33 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
34 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
35 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
36 |
+
|
37 |
+
def attention(self, h_: Tensor) -> Tensor:
|
38 |
+
h_ = self.norm(h_)
|
39 |
+
q = self.q(h_)
|
40 |
+
k = self.k(h_)
|
41 |
+
v = self.v(h_)
|
42 |
+
|
43 |
+
b, c, h, w = q.shape
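        # flatten the spatial grid into a length-(h*w) token sequence and run
        # single-head scaled-dot-product attention over it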
|
44 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
45 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
46 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
47 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
48 |
+
|
49 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
50 |
+
|
51 |
+
def forward(self, x: Tensor) -> Tensor:
|
52 |
+
return x + self.proj_out(self.attention(x))
|
53 |
+
|
54 |
+
|
55 |
+
class ResnetBlock(nn.Module):
|
56 |
+
def __init__(self, in_channels: int, out_channels: int | None = None):
|
57 |
+
super().__init__()
|
58 |
+
self.in_channels = in_channels
|
59 |
+
out_channels = in_channels if out_channels is None else out_channels
|
60 |
+
self.out_channels = out_channels
|
61 |
+
|
62 |
+
self.norm1 = nn.GroupNorm(
|
63 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
64 |
+
)
|
65 |
+
self.conv1 = nn.Conv2d(
|
66 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
67 |
+
)
|
68 |
+
self.norm2 = nn.GroupNorm(
|
69 |
+
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
70 |
+
)
|
71 |
+
self.conv2 = nn.Conv2d(
|
72 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
73 |
+
)
|
74 |
+
if self.in_channels != self.out_channels:
|
75 |
+
self.nin_shortcut = nn.Conv2d(
|
76 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
h = x
|
81 |
+
h = self.norm1(h)
|
82 |
+
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
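For reference, a minimal round-trip sketch for the autoencoder above (not part of this commit; the AutoEncoderParams values are the ones hard-coded in util.load_config further down, so the four-level ch_mult gives an 8x spatial downsample and a 16-channel latent):

import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 256, 256)   # dummy image batch
z = ae.encode(x)                  # (1, 16, 32, 32): sampled, scaled/shifted latent
x_hat = ae.decode(z)              # (1, 3, 256, 256) reconstruction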
modules/conditioner.py
ADDED
@@ -0,0 +1,128 @@
import os

import torch
from torch import Tensor, nn
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    __version__,
)
from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig

CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")


def auto_quantization_config(
    quantization_dtype: str,
) -> QuantoConfig | BitsAndBytesConfig | None:
    if quantization_dtype == "qfloat8":
        return QuantoConfig(weights="float8")
    elif quantization_dtype == "qint4":
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    elif quantization_dtype == "qint8":
        return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
    elif quantization_dtype == "qint2":
        return QuantoConfig(weights="int2")
    elif quantization_dtype is None or quantization_dtype == "bfloat16":
        return None
    else:
        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")


class HFEmbedder(nn.Module):
    def __init__(
        self,
        version: str,
        max_length: int,
        device: torch.device | int,
        quantization_dtype: str | None = None,
        offloading_device: torch.device | int | None = torch.device("cpu"),
        is_clip: bool = False,
        **hf_kwargs,
    ):
        super().__init__()
        self.offloading_device = (
            offloading_device
            if isinstance(offloading_device, torch.device)
            else torch.device(offloading_device)
        )
        self.device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        self.is_clip = version.startswith("openai") or is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        auto_quant_config = (
            auto_quantization_config(quantization_dtype)
            if quantization_dtype is not None
            and quantization_dtype != "bfloat16"
            and quantization_dtype != "float16"
            else None
        )

        # BNB will move to cuda:0 by default if not specified
        if isinstance(auto_quant_config, BitsAndBytesConfig):
            hf_kwargs["device_map"] = {"": self.device.index}
        if auto_quant_config is not None:
            hf_kwargs["quantization_config"] = auto_quant_config

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )

            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version,
                **hf_kwargs,
            )

        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version,
                **hf_kwargs,
            )

    def offload(self):
        self.hf_module.to(device=self.offloading_device)
        torch.cuda.empty_cache()

    def cuda(self):
        self.hf_module.to(device=self.device)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]


if __name__ == "__main__":
    model = HFEmbedder(
        "city96/t5-v1_1-xxl-encoder-bf16",
        max_length=512,
        device=0,
        quantization_dtype="qfloat8",
    )
    o = model(["hello"])
    print(o)
modules/flux_model.py
ADDED
@@ -0,0 +1,734 @@
import os
from collections import namedtuple
from typing import TYPE_CHECKING, List

import torch
from loguru import logger

if TYPE_CHECKING:
    from lora_loading import LoraWeights
    from util import ModelSpec
DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
import math

from pydantic import BaseModel
from torch import Tensor, nn
from torch.nn import functional as F


class FluxParams(BaseModel):
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


# attention is always same shape each time it's called per H*W, so compile with fullgraph
# @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    x = x.reshape(*x.shape[:-2], -1)
    return x


# @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)


class EmbedND(nn.Module):
    def __init__(
        self,
        dim: int,
        theta: int,
        axes_dim: list[int],
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.dtype = dtype

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [
                rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
                for i in range(n_axes)
            ],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(
        self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
    ):
        from float8_quantize import F8Linear

        super().__init__()
        self.in_layer = (
            nn.Linear(in_dim, hidden_dim, bias=True)
            if not prequantized
            else (
                F8Linear(
                    in_features=in_dim,
                    out_features=hidden_dim,
                    bias=True,
                )
                if quantized
                else nn.Linear(in_dim, hidden_dim, bias=True)
            )
        )
        self.silu = nn.SiLU()
        self.out_layer = (
            nn.Linear(hidden_dim, hidden_dim, bias=True)
            if not prequantized
            else (
                F8Linear(
                    in_features=hidden_dim,
                    out_features=hidden_dim,
                    bias=True,
                )
                if quantized
                else nn.Linear(hidden_dim, hidden_dim, bias=True)
            )
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        return F.rms_norm(x.float(), self.scale.shape, self.scale, eps=1e-6).to(x)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        prequantized: bool = False,
    ):
        super().__init__()
        from float8_quantize import F8Linear

        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = (
            nn.Linear(dim, dim * 3, bias=qkv_bias)
            if not prequantized
            else F8Linear(
                in_features=dim,
                out_features=dim * 3,
                bias=qkv_bias,
            )
        )
        self.norm = QKNorm(head_dim)
        self.proj = (
            nn.Linear(dim, dim)
            if not prequantized
            else F8Linear(
                in_features=dim,
                out_features=dim,
                bias=True,
            )
        )
        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H

    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        B, L, D = x.shape
        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        return q, k, v

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = self.rearrange_for_norm(qkv)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
        super().__init__()
        from float8_quantize import F8Linear

        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = (
            nn.Linear(dim, self.multiplier * dim, bias=True)
            if not quantized_modulation
            else F8Linear(
                in_features=dim,
                out_features=self.multiplier * dim,
                bias=True,
            )
        )
        self.act = nn.SiLU()

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        dtype: torch.dtype = torch.float16,
        quantized_modulation: bool = False,
        prequantized: bool = False,
    ):
        super().__init__()
        from float8_quantize import F8Linear

        self.dtype = dtype

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(
            hidden_size, double=True, quantized_modulation=quantized_modulation
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            prequantized=prequantized,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            (
                nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
                if not prequantized
                else F8Linear(
                    in_features=hidden_size,
                    out_features=mlp_hidden_dim,
                    bias=True,
                )
            ),
            nn.GELU(approximate="tanh"),
            (
                nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
                if not prequantized
                else F8Linear(
                    in_features=mlp_hidden_dim,
                    out_features=hidden_size,
                    bias=True,
                )
            ),
        )

        self.txt_mod = Modulation(
            hidden_size, double=True, quantized_modulation=quantized_modulation
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            prequantized=prequantized,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            (
                nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
                if not prequantized
                else F8Linear(
                    in_features=hidden_size,
                    out_features=mlp_hidden_dim,
                    bias=True,
                )
            ),
            nn.GELU(approximate="tanh"),
            (
                nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
                if not prequantized
                else F8Linear(
                    in_features=mlp_hidden_dim,
                    out_features=hidden_size,
                    bias=True,
                )
            ),
        )
        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H
        self.do_clamp = dtype == torch.float16

    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        B, L, D = x.shape
        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        return q, k, v

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        if self.do_clamp:
            img = img.clamp(min=-32000, max=32000)
            txt = txt.clamp(min=-32000, max=32000)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        dtype: torch.dtype = torch.float16,
        quantized_modulation: bool = False,
        prequantized: bool = False,
    ):
        super().__init__()
        from float8_quantize import F8Linear

        self.dtype = dtype
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = (
            nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
            if not prequantized
            else F8Linear(
                in_features=hidden_size,
                out_features=hidden_size * 3 + self.mlp_hidden_dim,
                bias=True,
            )
        )
        # proj and mlp_out
        self.linear2 = (
            nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
            if not prequantized
            else F8Linear(
                in_features=hidden_size + self.mlp_hidden_dim,
                out_features=hidden_size,
                bias=True,
            )
        )

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(
            hidden_size,
            double=False,
            quantized_modulation=quantized_modulation and prequantized,
        )

        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H
        self.do_clamp = dtype == torch.float16

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod = self.modulation(vec)[0]
        pre_norm = self.pre_norm(x)
        x_mod = (1 + mod.scale) * pre_norm + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim],
            dim=-1,
        )
        B, L, D = qkv.shape
        q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        if self.do_clamp:
            out = (x + mod.gate * output).clamp(min=-32000, max=32000)
        else:
            out = x + mod.gate * output
        return out


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
        super().__init__()

        self.dtype = dtype
        self.params = config.params
        self.in_channels = config.params.in_channels
        self.out_channels = self.in_channels
        self.loras: List[LoraWeights] = []
        prequantized_flow = config.prequantized_flow
        quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
        quantized_modulation = config.quantize_modulation and prequantized_flow
        from float8_quantize import F8Linear

        if config.params.hidden_size % config.params.num_heads != 0:
            raise ValueError(
                f"Hidden size {config.params.hidden_size} must be divisible by num_heads {config.params.num_heads}"
            )
        pe_dim = config.params.hidden_size // config.params.num_heads
        if sum(config.params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {config.params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = config.params.hidden_size
        self.num_heads = config.params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim,
            theta=config.params.theta,
            axes_dim=config.params.axes_dim,
            dtype=self.dtype,
        )
        self.img_in = (
            nn.Linear(self.in_channels, self.hidden_size, bias=True)
            if not prequantized_flow
            else (
                F8Linear(
                    in_features=self.in_channels,
                    out_features=self.hidden_size,
                    bias=True,
                )
                if quantized_embedders
                else nn.Linear(self.in_channels, self.hidden_size, bias=True)
            )
        )
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            prequantized=prequantized_flow,
            quantized=quantized_embedders,
        )
        self.vector_in = MLPEmbedder(
            config.params.vec_in_dim,
            self.hidden_size,
            prequantized=prequantized_flow,
            quantized=quantized_embedders,
        )
        self.guidance_in = (
            MLPEmbedder(
                in_dim=256,
                hidden_dim=self.hidden_size,
                prequantized=prequantized_flow,
                quantized=quantized_embedders,
            )
            if config.params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = (
            nn.Linear(config.params.context_in_dim, self.hidden_size)
            if not quantized_embedders
            else (
                F8Linear(
                    in_features=config.params.context_in_dim,
                    out_features=self.hidden_size,
                    bias=True,
                )
                if quantized_embedders
                else nn.Linear(config.params.context_in_dim, self.hidden_size)
            )
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.params.mlp_ratio,
                    qkv_bias=config.params.qkv_bias,
                    dtype=self.dtype,
                    quantized_modulation=quantized_modulation,
                    prequantized=prequantized_flow,
                )
                for _ in range(config.params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.params.mlp_ratio,
                    dtype=self.dtype,
                    quantized_modulation=quantized_modulation,
                    prequantized=prequantized_flow,
                )
                for _ in range(config.params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def get_lora(self, identifier: str):
        for lora in self.loras:
            if lora.path == identifier or lora.name == identifier:
                return lora

    def has_lora(self, identifier: str):
        for lora in self.loras:
            if lora.path == identifier or lora.name == identifier:
                return True

    def load_lora(self, path: str, scale: float, name: str = None):
        from lora_loading import (
            LoraWeights,
            apply_lora_to_model,
            remove_lora_from_module,
        )

        if self.has_lora(path):
            lora = self.get_lora(path)
            if lora.scale == scale:
                logger.warning(
                    f"Lora {lora.name} already loaded with same scale - ignoring!"
                )
            else:
                remove_lora_from_module(self, lora, lora.scale)
                apply_lora_to_model(self, lora, scale)
                for idx, lora_ in enumerate(self.loras):
                    if lora_.path == lora.path:
                        self.loras[idx].scale = scale
                        break
        else:
            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
            self.loras.append(LoraWeights(lora, path, name, scale))

    def unload_lora(self, path_or_identifier: str):
        from lora_loading import remove_lora_from_module

        removed = False
        for idx, lora_ in enumerate(list(self.loras)):
            if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
                remove_lora_from_module(self, lora_.weights, lora_.scale)
                self.loras.pop(idx)
                removed = True
                break
        if not removed:
            logger.warning(
                f"Couldn't remove lora {path_or_identifier} as it wasn't found fused to the model!"
            )
        else:
            logger.info("Successfully removed lora from module.")

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(
                timestep_embedding(guidance, 256).type(self.dtype)
            )
        vec = vec + self.vector_in(y)

        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        # double stream blocks
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)

        # single stream blocks
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)

        img = img[:, txt.shape[1] :, ...]
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_pretrained(
        cls: "Flux", path: str, dtype: torch.dtype = torch.float16
    ) -> "Flux":
        from safetensors.torch import load_file

        from util import load_config_from_path

        config = load_config_from_path(path)
        with torch.device("meta"):
            klass = cls(config=config, dtype=dtype)
            if not config.prequantized_flow:
                klass.type(dtype)

        ckpt = load_file(config.ckpt_path, device="cpu")
        klass.load_state_dict(ckpt, assign=True)
        return klass.to("cpu")
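A smoke-test sketch of the Flux forward contract (not part of this commit; shapes are inferred from the flux-dev FluxParams constructed in util.load_config below, and `config` is assumed to be such a ModelSpec):

import torch
from modules.flux_model import Flux

B, n_img, n_txt = 1, 1024, 512
img = torch.randn(B, n_img, 64)      # packed 2x2 latent patches, in_channels=64
img_ids = torch.zeros(B, n_img, 3)   # per-token (batch, y, x) positions for RoPE
txt = torch.randn(B, n_txt, 4096)    # T5 last_hidden_state, context_in_dim=4096
txt_ids = torch.zeros(B, n_txt, 3)
t = torch.rand(B)                    # flow-matching timesteps
y = torch.randn(B, 768)              # pooled CLIP embedding, vec_in_dim=768
g = torch.full((B,), 3.5)            # required because flux-dev sets guidance_embed=True

model = Flux(config, dtype=torch.bfloat16)
pred = model(img, img_ids, txt, txt_ids, timesteps=t, y=y, guidance=g)
# pred: (B, n_img, 64), one velocity vector per packed latent token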
photo.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8476d6d93124c5265bb9ff0b393600b7ca26b3a566822c036cd9f59141065a9b
size 174924704
start.py
ADDED
File without changes
util.py
ADDED
@@ -0,0 +1,333 @@
import json
from pathlib import Path
from typing import Literal, Optional

import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams
from modules.conditioner import HFEmbedder
from modules.flux_model import Flux, FluxParams
from safetensors.torch import load_file as load_sft

try:
    from enum import StrEnum
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):
        pass


from pydantic import BaseModel, ConfigDict
from loguru import logger


class ModelVersion(StrEnum):
    flux_dev = "flux-dev"
    flux_schnell = "flux-schnell"


class QuantizationDtype(StrEnum):
    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"
    bfloat16 = "bfloat16"
    float16 = "float16"


class ModelSpec(BaseModel):
    version: ModelVersion
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    # Add option to pass in custom clip model
    clip_path: str | None = "openai/clip-vit-large-patch14"
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    text_enc_max_length: int = 512
    text_enc_path: str | None
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"
    # unused / deprecated
    num_to_quant: Optional[int] = 20
    quantize_extras: bool = False
    compile_extras: bool = False
    compile_blocks: bool = False
    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    ae_quantization_dtype: Optional[QuantizationDtype] = None
    clip_quantization_dtype: Optional[QuantizationDtype] = None
    offload_text_encoder: bool = False
    offload_vae: bool = False
    offload_flow: bool = False
    prequantized_flow: bool = False

    # Improved precision via not quantizing the modulation linear layers
    quantize_modulation: bool = True
    # Improved precision via not quantizing the flow embedder layers
    quantize_flow_embedder_layers: bool = False

    model_config: ConfigDict = {
        "arbitrary_types_allowed": True,
        "use_enum_values": True,
    }


def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
    flow = load_flow_model(config)
    ae = load_autoencoder(config)
    clip, t5 = load_text_encoders(config)
    return flow, ae, clip, t5


def parse_device(device: str | torch.device | None) -> torch.device:
    if isinstance(device, str):
        return torch.device(device)
    elif isinstance(device, torch.device):
        return device
    else:
        return torch.device("cuda:0")


def into_dtype(dtype: str) -> torch.dtype:
    if isinstance(dtype, torch.dtype):
        return dtype
    if dtype == "float16":
        return torch.float16
    elif dtype == "bfloat16":
        return torch.bfloat16
    elif dtype == "float32":
        return torch.float32
    else:
        raise ValueError(f"Invalid dtype: {dtype}")


def into_device(device: str | torch.device | None) -> torch.device:
    if isinstance(device, str):
        return torch.device(device)
    elif isinstance(device, torch.device):
        return device
    elif isinstance(device, int):
        return torch.device(f"cuda:{device}")
    else:
        return torch.device("cuda:0")


def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
    compile_extras: bool = False,
    compile_blocks: bool = False,
    offload_text_enc: bool = False,
    offload_ae: bool = False,
    offload_flow: bool = False,
    quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
    quant_ae: bool = False,
    prequantized_flow: bool = False,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = False,
) -> ModelSpec:
    """
    Load a model configuration using the passed arguments.
    """
    text_enc_device = str(parse_device(text_enc_device))
    ae_device = str(parse_device(ae_device))
    flux_device = str(parse_device(flux_device))
    return ModelSpec(
        version=name,
        repo_id=(
            "black-forest-labs/FLUX.1-dev"
            if name == ModelVersion.flux_dev
            else "black-forest-labs/FLUX.1-schnell"
        ),
        repo_flow=(
            "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft"
        ),
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=name == ModelVersion.flux_dev,
        ),
        ae_path=ae_path,
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
        text_enc_path=text_enc_path,
        text_enc_device=text_enc_device,
        ae_device=ae_device,
        flux_device=flux_device,
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
        num_to_quant=num_to_quant,
        compile_extras=compile_extras,
        compile_blocks=compile_blocks,
        offload_flow=offload_flow,
        offload_text_encoder=offload_text_enc,
        offload_vae=offload_ae,
        text_enc_quantization_dtype={
            "float8": QuantizationDtype.qfloat8,
            "qint2": QuantizationDtype.qint2,
            "qint4": QuantizationDtype.qint4,
            "qint8": QuantizationDtype.qint8,
        }.get(quant_text_enc, None),
        ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
        prequantized_flow=prequantized_flow,
        quantize_modulation=quantize_modulation,
        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
    )


def load_config_from_path(path: str) -> ModelSpec:
    path_path = Path(path)
    if not path_path.exists():
        raise ValueError(f"Path {path} does not exist")
    if not path_path.is_file():
        raise ValueError(f"Path {path} is not a file")
    return ModelSpec(**json.loads(path_path.read_text()))


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        logger.warning("\n" + "-" * 79 + "\n")
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
    elif len(missing) > 0:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )


def load_flow_model(config: ModelSpec) -> Flux:
    ckpt_path = config.ckpt_path
    FluxClass = Flux

    with torch.device("meta"):
        model = FluxClass(config, dtype=into_dtype(config.flow_dtype))
        if not config.prequantized_flow:
            model.type(into_dtype(config.flow_dtype))

    if ckpt_path is not None:
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device="cpu")
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
        if not config.prequantized_flow:
            model.type(into_dtype(config.flow_dtype))
    return model


def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    clip = HFEmbedder(
        config.clip_path,
        max_length=77,
        torch_dtype=into_dtype(config.text_enc_dtype),
        device=into_device(config.text_enc_device).index or 0,
        is_clip=True,
        quantization_dtype=config.clip_quantization_dtype,
    )
    t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=into_dtype(config.text_enc_dtype),
        device=into_device(config.text_enc_device).index or 0,
        quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip, t5


def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    ckpt_path = config.ae_path
    with torch.device("meta" if ckpt_path is not None else config.ae_device):
        ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(config.ae_device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
    if config.ae_quantization_dtype is not None:
        from float8_quantize import recursive_swap_linears

        recursive_swap_linears(ae)
    if config.offload_vae:
        ae.to("cpu")
        torch.cuda.empty_cache()
    return ae


class LoadedModels(BaseModel):
    flow: Flux
    ae: AutoEncoder
    clip: HFEmbedder
    t5: HFEmbedder
    config: ModelSpec

    model_config = {
        "arbitrary_types_allowed": True,
        "use_enum_values": True,
    }


def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    config = load_config_from_path(path)
    clip, t5 = load_text_encoders(config)
    return LoadedModels(
        flow=load_flow_model(config),
        ae=load_autoencoder(config),
        clip=clip,
        t5=t5,
        config=config,
    )


def load_models_from_config(config: ModelSpec) -> LoadedModels:
    clip, t5 = load_text_encoders(config)
    return LoadedModels(
        flow=load_flow_model(config),
        ae=load_autoencoder(config),
        clip=clip,
        t5=t5,
        config=config,
    )
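As a usage note, a hedged loading sketch (the local paths are placeholders, not repo defaults; any of the JSON files in configs/ can be used instead via load_models_from_config_path):

from util import ModelVersion, load_config, load_models_from_config

config = load_config(
    name=ModelVersion.flux_dev,
    flux_path="/workspace/flux1-dev.sft",        # placeholder checkpoint path
    ae_path="/workspace/ae.safetensors",         # placeholder VAE path
    text_enc_path="city96/t5-v1_1-xxl-encoder-bf16",
    quant_text_enc="qfloat8",
)
models = load_models_from_config(config)
# models.flow, models.ae, models.clip, models.t5 are ready for the pipeline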