marks committed on
Commit 3778bc0 · 1 Parent(s): 002c192
.gitignore ADDED
@@ -0,0 +1,17 @@
+ __pycache__
+ *.jpg
+ *.png
+ *.jpeg
+ *.gif
+ *.bmp
+ *.webp
+ *.mp4
+ *.mp3
+ *.mp3
+ *.txt
+ .copilotignore
+ .misc
+ BFL-flux-diffusers
+ .env
+ .env.*
+ perfection.safetensors
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ # Base image with Python 3.11, PyTorch, CUDA 12.4.1, and Ubuntu 22.04
+ FROM runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
+
+ # Set the working directory inside the container
+ WORKDIR /workspace
+
+ # Install necessary packages and dependencies
+ RUN apt-get update && apt-get install -y \
+     wget \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Clone your repository
+ RUN git clone https://github.com/Yuanshi9815/OminiControl
+
+ # Change directory to the cloned repo
+ WORKDIR /workspace/fp8
+
+ # Install Python dependencies
+ RUN pip install -r requirements.txt
+
+ # Download the required model files
+ RUN wget https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors -O /workspace/flux1-schnell.safetensors && \
+     wget https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors -O /workspace/ae.safetensors
+
+ # Expose necessary HTTP ports
+ EXPOSE 8888 7860
+
+ # Set the command to run your Python script
+ CMD ["python", "main_gr.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 Alex Redden
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
api.py ADDED
@@ -0,0 +1,122 @@
+ from typing import Literal, Optional, TYPE_CHECKING
+
+ import numpy as np
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse, JSONResponse
+ from pydantic import BaseModel, Field
+ from platform import system
+
+ if TYPE_CHECKING:
+     from flux_pipeline import FluxPipeline
+
+ if system() == "Windows":
+     MAX_RAND = 2**16 - 1
+ else:
+     MAX_RAND = 2**32 - 1
+
+
+ class AppState:
+     model: "FluxPipeline"
+
+
+ class FastAPIApp(FastAPI):
+     state: AppState
+
+
+ class LoraArgs(BaseModel):
+     scale: Optional[float] = 1.0
+     path: Optional[str] = None
+     name: Optional[str] = None
+     action: Optional[Literal["load", "unload"]] = "load"
+
+
+ class LoraLoadResponse(BaseModel):
+     status: Literal["success", "error"]
+     message: Optional[str] = None
+
+
+ class GenerateArgs(BaseModel):
+     prompt: str
+     width: Optional[int] = Field(default=720)
+     height: Optional[int] = Field(default=1024)
+     num_steps: Optional[int] = Field(default=24)
+     guidance: Optional[float] = Field(default=3.5)
+     seed: Optional[int] = Field(
+         default_factory=lambda: np.random.randint(0, MAX_RAND), gt=0, lt=MAX_RAND
+     )
+     strength: Optional[float] = 1.0
+     init_image: Optional[str] = None
+
+
+ app = FastAPIApp()
+
+
+ @app.post("/generate")
+ def generate(args: GenerateArgs):
+     """
+     Generates an image from the Flux flow transformer.
+
+     Args:
+         args (GenerateArgs): Arguments for image generation:
+
+             - `prompt`: The prompt used for image generation.
+
+             - `width`: The width of the image.
+
+             - `height`: The height of the image.
+
+             - `num_steps`: The number of steps for the image generation.
+
+             - `guidance`: The guidance for image generation, represents the
+                 influence of the prompt on the image generation.
+
+             - `seed`: The seed for the image generation.
+
+             - `strength`: strength for image generation, 0.0 - 1.0.
+                 Represents the percent of diffusion steps to run,
+                 setting the init_image as the noised latent at the
+                 given number of steps.
+
+             - `init_image`: Base64 encoded image or path to image to use as the init image.
+
+     Returns:
+         StreamingResponse: The generated image as streaming jpeg bytes.
+     """
+     result = app.state.model.generate(**args.model_dump())
+     return StreamingResponse(result, media_type="image/jpeg")
+
+
+ @app.post("/lora", response_model=LoraLoadResponse)
+ def lora_action(args: LoraArgs):
+     """
+     Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.
+
+     Args:
+         args (LoraArgs): Arguments for the LoRA action:
+
+             - `scale`: The scaling factor for the LoRA weights.
+             - `path`: The path to the LoRA checkpoint.
+             - `name`: The name of the LoRA checkpoint.
+             - `action`: The action to perform, either "load" or "unload".
+
+     Returns:
+         LoraLoadResponse: The status of the LoRA action.
+     """
+     try:
+         if args.action == "load":
+             app.state.model.load_lora(args.path, args.scale, args.name)
+         elif args.action == "unload":
+             app.state.model.unload_lora(args.name if args.name else args.path)
+         else:
+             return JSONResponse(
+                 content={
+                     "status": "error",
+                     "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
+                 },
+                 status_code=400,
+             )
+     except Exception as e:
+         return JSONResponse(
+             status_code=500, content={"status": "error", "message": str(e)}
+         )
+     return JSONResponse(status_code=200, content={"status": "success"})
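Note: the two routes above assume a `FluxPipeline` has already been attached to `app.state.model` and that the app is being served; the server launch is not part of this commit, so the host and port below are assumptions (8888 is one of the ports the Dockerfile exposes). A minimal client sketch:

```python
# Hypothetical client for the /generate and /lora endpoints defined above.
# The host/port are assumptions; point it at wherever the FastAPI app is served.
import requests

resp = requests.post(
    "http://localhost:8888/generate",
    json={"prompt": "a misty forest at dawn", "width": 720, "height": 1024, "num_steps": 24},
)
with open("out.jpg", "wb") as f:
    f.write(resp.content)  # StreamingResponse body: jpeg bytes

lora = requests.post(
    "http://localhost:8888/lora",
    json={"action": "load", "path": "dark.safetensors", "name": "dark", "scale": 1.0},
)
print(lora.json())  # {"status": "success"} on success, or an error message
```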
app.py CHANGED
@@ -1,7 +1,170 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()

+ import os
+ import subprocess
+ import spaces
+ import torch
+ from safetensors.torch import load_file
+ from flux_pipeline import FluxPipeline
  import gradio as gr
+ from PIL import Image

+ def download_models():
+     """
+     Download required models at application startup using wget.
+     """
+     model_urls = [
+         "https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/flux1-dev.safetensors",
+         "https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/ae.safetensors",
+     ]
+     for url in model_urls:
+         filename = url.split("/")[-1]
+         if not os.path.exists(filename):
+             print(f"Downloading {filename}...")
+             subprocess.run(["wget", "-O", filename, url], check=True)
+         else:
+             print(f"{filename} already exists, skipping download.")

+     print("All models are ready.")
+
+
+ def load_sft(ckpt_path, device="cpu"):
+     """
+     Load a safetensors file.
+     Args:
+         ckpt_path (str): Local path to the safetensors file.
+         device (str): Device to load the file onto.
+     Returns:
+         Safetensors model state dictionary.
+     """
+     if os.path.exists(ckpt_path):
+         print(f"Loading local checkpoint: {ckpt_path}")
+         return load_file(ckpt_path, device=device)
+     else:
+         raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
+
+
+ def create_demo(config_path: str):
+     generator = FluxPipeline.load_pipeline_from_config_path(config_path)
+     load_sft("photo.safetensors", "cuda")
+     load_sft("dark.safetensors", "cuda")
+     load_sft("perfection.safetensors", "cuda")
+     @spaces.GPU
+     def generate_image(
+         prompt,
+         width,
+         height,
+         num_steps,
+         guidance,
+         seed,
+         init_image,
+         image2image_strength,
+         add_sampling_metadata,
+     ):
+
+         seed = int(seed)
+         if seed == -1:
+             seed = None
+         out = generator.generate(
+             prompt,
+             width,
+             height,
+             num_steps=num_steps,
+             guidance=guidance,
+             seed=seed,
+             init_image=init_image,
+             strength=image2image_strength,
+             silent=False,
+             num_images=1,
+             return_seed=True,
+         )
+         image_bytes = out[0]
+         return Image.open(image_bytes), str(out[1]), None
+
+     is_schnell = generator.config.version == "flux-schnell"
+
+     with gr.Blocks() as demo:
+         gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}")
+
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
+                 )
+                 do_img2img = gr.Checkbox(
+                     label="Image to Image", value=False, interactive=not is_schnell
+                 )
+                 init_image = gr.Image(label="Input Image", visible=False)
+                 image2image_strength = gr.Slider(
+                     0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
+                 )
+
+                 with gr.Accordion("Advanced Options", open=False):
+                     width = gr.Slider(128, 8192, 1152, step=16, label="Width")
+                     height = gr.Slider(128, 8192, 640, step=16, label="Height")
+                     num_steps = gr.Slider(
+                         1, 50, 4 if is_schnell else 20, step=1, label="Number of steps"
+                     )
+                     guidance = gr.Slider(
+                         1.0,
+                         10.0,
+                         3.5,
+                         step=0.1,
+                         label="Guidance",
+                         interactive=not is_schnell,
+                     )
+                     seed = gr.Textbox(-1, label="Seed (-1 for random)")
+                     add_sampling_metadata = gr.Checkbox(
+                         label="Add sampling parameters to metadata?", value=True
+                     )
+
+                 generate_btn = gr.Button("Generate")
+
+             with gr.Column(min_width="960px"):
+                 output_image = gr.Image(label="Generated Image")
+                 seed_output = gr.Number(label="Used Seed")
+                 warning_text = gr.Textbox(label="Warning", visible=False)
+
+         def update_img2img(do_img2img):
+             return {
+                 init_image: gr.update(visible=do_img2img),
+                 image2image_strength: gr.update(visible=do_img2img),
+             }
+
+         do_img2img.change(
+             update_img2img, do_img2img, [init_image, image2image_strength]
+         )
+
+         generate_btn.click(
+             fn=generate_image,
+             inputs=[
+                 prompt,
+                 width,
+                 height,
+                 num_steps,
+                 guidance,
+                 seed,
+                 init_image,
+                 image2image_strength,
+                 add_sampling_metadata,
+             ],
+             outputs=[output_image, seed_output, warning_text],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Flux")
+     parser.add_argument(
+         "--config", type=str, default="configs/config-dev-1-RTX6000ADA.json", help="Config file path"
+     )
+     parser.add_argument(
+         "--share", action="store_true", help="Create a public link to your demo"
+     )
+
+     args = parser.parse_args()
+
+     demo = create_demo(args.config)
+     demo.launch(share=args.share)
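Usage note: the demo above is normally started through the `__main__` block (`python app.py --config configs/config-dev-1-RTX6000ADA.json --share`), but it can also be driven programmatically. A minimal sketch, assuming `flux_pipeline` and the referenced checkpoints are available locally:

```python
# Sketch only: launch the Gradio demo defined in app.py without the CLI.
# Assumes the config file and the .safetensors checkpoints it references exist.
from app import create_demo

demo = create_demo("configs/config-dev-1-RTX6000ADA.json")
# 7860 is one of the ports exposed in the Dockerfile above.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
```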
configs/config-dev-1-RTX6000ADA-Copy1.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": false,
+   "offload_vae": false,
+   "offload_flow": false
+ }
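The configs that follow all share the same FLUX.1 transformer and autoencoder geometry (`params` / `ae_params`) and differ only in checkpoint paths, device placement, dtypes, quantization dtypes, and the compile/offload flags. A short sketch of how a config is consumed, based on the `FluxPipeline.load_pipeline_from_config_path` call used in app.py (the `flux_pipeline` module itself is not part of this diff):

```python
# Sketch: inspect a config file and hand it to the pipeline loader used in app.py.
import json

with open("configs/config-dev-1-RTX6000ADA.json") as f:
    cfg = json.load(f)
print(cfg["version"], cfg.get("flow_quantization_dtype"), cfg["offload_flow"])

from flux_pipeline import FluxPipeline  # assumption: importable from the repo root

pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev-1-RTX6000ADA.json")
```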
configs/config-dev-1-RTX6000ADA.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": false,
+   "offload_vae": false,
+   "offload_flow": false
+ }
configs/config-dev-cuda0.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
configs/config-dev-eval.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:1",
+   "ae_device": "cuda:1",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
configs/config-dev-gigaquant.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "num_to_quant": 220,
+   "flow_quantization_dtype": "qint4",
+   "text_enc_quantization_dtype": "qint4",
+   "ae_quantization_dtype": "qint4",
+   "clip_quantization_dtype": "qint4",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "quantize_extras": true
+ }
configs/config-dev-offload-1-4080.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qint4",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": true,
+   "offload_vae": true,
+   "offload_flow": true
+ }
configs/config-dev-offload-1-4090.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qint4",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": true,
+   "offload_vae": true,
+   "offload_flow": false
+ }
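The -4080 and -4090 variants above are intentionally close; only the flow-offload flag differs (the 4090 keeps the flow transformer resident, the 4080 offloads it). A small sketch for comparing any two of these configs mechanically:

```python
# Sketch: print the top-level keys on which two of the config files above disagree.
import json


def config_diff(path_a: str, path_b: str) -> dict:
    a, b = (json.load(open(p)) for p in (path_a, path_b))
    return {k: (a.get(k), b.get(k)) for k in set(a) | set(b) if a.get(k) != b.get(k)}


print(config_diff(
    "configs/config-dev-offload-1-4080.json",
    "configs/config-dev-offload-1-4090.json",
))  # expected: {'offload_flow': (True, False)}
```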
configs/config-dev-offload.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qint4",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "offload_text_encoder": true,
+   "offload_vae": true,
+   "offload_flow": true
+ }
configs/config-dev-prequant.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/flux-fp16-acc/flux_fp8.safetensors",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:1",
+   "ae_device": "cuda:1",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "prequantized_flow": true,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
configs/config-dev.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:1",
+   "ae_device": "cuda:1",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "text_enc_quantization_dtype": "qfloat8",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
configs/config-f8.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "version": "flux-schnell",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [16, 56, 56],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": false
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [1, 2, 4, 4],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "flux1-schnell.safetensors",
+   "ae_path": "ae.safetensors",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 256,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": false,
+   "offload_vae": false,
+   "offload_flow": false
+ }
configs/config-schnell-cuda0.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-schnell",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": false
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-schnell",
+   "repo_flow": "flux1-schnell.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 256,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "text_enc_quantization_dtype": "qfloat8",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": false,
+   "compile_blocks": false,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
configs/config-schnell.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "version": "flux-schnell",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": false
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-schnell",
+   "repo_flow": "flux1-schnell.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 256,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:1",
+   "ae_device": "cuda:1",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "text_enc_quantization_dtype": "qfloat8",
+   "ae_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_ae": false,
+   "offload_text_enc": false,
+   "offload_flow": false
+ }
dark.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c633fc7d1af2452f0680abdc20baa285c43e107ae9a32fbf995c55c13bf0c4dd
+ size 39759552
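dark.safetensors is committed as a Git LFS pointer, so only the oid and size (roughly 40 MB) appear in the diff; the tensor data lives in LFS storage. Once pulled (`git lfs pull`), the file can be inspected with the same `safetensors` loader that app.py imports. A sketch:

```python
# Sketch: peek at the dark.safetensors weights after the LFS object has been fetched.
from safetensors.torch import load_file

state = load_file("dark.safetensors", device="cpu")
print(len(state), "tensors")
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)
```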
f8.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "version": "flux-schnell",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [16, 56, 56],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": false
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [1, 2, 4, 4],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "flux1-schnell.safetensors",
+   "ae_path": "ae.safetensors",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "compile_extras": true,
+   "compile_blocks": true,
+   "offload_text_encoder": false,
+   "offload_vae": false,
+   "offload_flow": false
+ }
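f8.json duplicates configs/config-f8.json except for `text_enc_max_length` (512 here versus 256 there). Its checkpoint paths are relative (`flux1-schnell.safetensors`, `ae.safetensors`) while the Dockerfile downloads those files into /workspace, so a quick pre-flight check like the sketch below tells you whether the paths resolve from the current working directory before launching:

```python
# Sketch: verify the checkpoints referenced by f8.json exist relative to the CWD.
import json
import os

cfg = json.load(open("f8.json"))
for key in ("ckpt_path", "ae_path"):
    path = cfg[key]
    print(key, path, "OK" if os.path.exists(path) else "not found locally")
```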
float8_quantize.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import init
5
+ import math
6
+ from torch.compiler import is_compiling
7
+ from torch import __version__
8
+ from torch.version import cuda
9
+
10
+ from modules.flux_model import Modulation
11
+
12
+ IS_TORCH_2_4 = __version__ < (2, 4, 9)
13
+ LT_TORCH_2_4 = __version__ < (2, 4)
14
+ if LT_TORCH_2_4:
15
+ if not hasattr(torch, "_scaled_mm"):
16
+ raise RuntimeError(
17
+ "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
18
+ )
19
+ CUDA_VERSION = float(cuda) if cuda else 0
20
+ if CUDA_VERSION < 12.4:
21
+ raise RuntimeError(
22
+ f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}."
23
+ )
24
+ try:
25
+ from cublas_ops import CublasLinear
26
+ except ImportError:
27
+ CublasLinear = type(None)
28
+
29
+
30
+ class F8Linear(nn.Module):
31
+
32
+ def __init__(
33
+ self,
34
+ in_features: int,
35
+ out_features: int,
36
+ bias: bool = True,
37
+ device=None,
38
+ dtype=torch.float16,
39
+ float8_dtype=torch.float8_e4m3fn,
40
+ float_weight: torch.Tensor = None,
41
+ float_bias: torch.Tensor = None,
42
+ num_scale_trials: int = 12,
43
+ input_float8_dtype=torch.float8_e5m2,
44
+ ) -> None:
45
+ super().__init__()
46
+ self.in_features = in_features
47
+ self.out_features = out_features
48
+ self.float8_dtype = float8_dtype
49
+ self.input_float8_dtype = input_float8_dtype
50
+ self.input_scale_initialized = False
51
+ self.weight_initialized = False
52
+ self.max_value = torch.finfo(self.float8_dtype).max
53
+ self.input_max_value = torch.finfo(self.input_float8_dtype).max
54
+ factory_kwargs = {"dtype": dtype, "device": device}
55
+ if float_weight is None:
56
+ self.weight = nn.Parameter(
57
+ torch.empty((out_features, in_features), **factory_kwargs)
58
+ )
59
+ else:
60
+ self.weight = nn.Parameter(
61
+ float_weight, requires_grad=float_weight.requires_grad
62
+ )
63
+ if float_bias is None:
64
+ if bias:
65
+ self.bias = nn.Parameter(
66
+ torch.empty(out_features, **factory_kwargs),
67
+ )
68
+ else:
69
+ self.register_parameter("bias", None)
70
+ else:
71
+ self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad)
72
+ self.num_scale_trials = num_scale_trials
73
+ self.input_amax_trials = torch.zeros(
74
+ num_scale_trials, requires_grad=False, device=device, dtype=torch.float32
75
+ )
76
+ self.trial_index = 0
77
+ self.register_buffer("scale", None)
78
+ self.register_buffer(
79
+ "input_scale",
80
+ None,
81
+ )
82
+ self.register_buffer(
83
+ "float8_data",
84
+ None,
85
+ )
86
+ self.scale_reciprocal = self.register_buffer("scale_reciprocal", None)
87
+ self.input_scale_reciprocal = self.register_buffer(
88
+ "input_scale_reciprocal", None
89
+ )
90
+
91
+ def _load_from_state_dict(
92
+ self,
93
+ state_dict,
94
+ prefix,
95
+ local_metadata,
96
+ strict,
97
+ missing_keys,
98
+ unexpected_keys,
99
+ error_msgs,
100
+ ):
101
+ sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
102
+ if "weight" in sd:
103
+ if (
104
+ "float8_data" not in sd
105
+ or sd["float8_data"] is None
106
+ and sd["weight"].shape == (self.out_features, self.in_features)
107
+ ):
108
+ # Initialize as if it's an F8Linear that needs to be quantized
109
+ self._parameters["weight"] = nn.Parameter(
110
+ sd["weight"], requires_grad=False
111
+ )
112
+ if "bias" in sd:
113
+ self._parameters["bias"] = nn.Parameter(
114
+ sd["bias"], requires_grad=False
115
+ )
116
+ self.quantize_weight()
117
+ elif sd["float8_data"].shape == (
118
+ self.out_features,
119
+ self.in_features,
120
+ ) and sd["weight"] == torch.zeros_like(sd["weight"]):
121
+ w = sd["weight"]
122
+ # Set the init values as if it's already quantized float8_data
123
+ self._buffers["float8_data"] = sd["float8_data"]
124
+ self._parameters["weight"] = nn.Parameter(
125
+ torch.zeros(
126
+ 1,
127
+ dtype=w.dtype,
128
+ device=w.device,
129
+ requires_grad=False,
130
+ )
131
+ )
132
+ if "bias" in sd:
133
+ self._parameters["bias"] = nn.Parameter(
134
+ sd["bias"], requires_grad=False
135
+ )
136
+ self.weight_initialized = True
137
+
138
+ # Check if scales and reciprocals are initialized
139
+ if all(
140
+ key in sd
141
+ for key in [
142
+ "scale",
143
+ "input_scale",
144
+ "scale_reciprocal",
145
+ "input_scale_reciprocal",
146
+ ]
147
+ ):
148
+ self.scale = sd["scale"].float()
149
+ self.input_scale = sd["input_scale"].float()
150
+ self.scale_reciprocal = sd["scale_reciprocal"].float()
151
+ self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
152
+ self.input_scale_initialized = True
153
+ self.trial_index = self.num_scale_trials
154
+ elif "scale" in sd and "scale_reciprocal" in sd:
155
+ self.scale = sd["scale"].float()
156
+ self.input_scale = (
157
+ sd["input_scale"].float() if "input_scale" in sd else None
158
+ )
159
+ self.scale_reciprocal = sd["scale_reciprocal"].float()
160
+ self.input_scale_reciprocal = (
161
+ sd["input_scale_reciprocal"].float()
162
+ if "input_scale_reciprocal" in sd
163
+ else None
164
+ )
165
+ self.input_scale_initialized = (
166
+ True if "input_scale" in sd else False
167
+ )
168
+ self.trial_index = (
169
+ self.num_scale_trials if "input_scale" in sd else 0
170
+ )
171
+ self.input_amax_trials = torch.zeros(
172
+ self.num_scale_trials,
173
+ requires_grad=False,
174
+ dtype=torch.float32,
175
+ device=self.weight.device,
176
+ )
177
+ self.input_scale_initialized = False
178
+ self.trial_index = 0
179
+ else:
180
+ # If scales are not initialized, reset trials
181
+ self.input_scale_initialized = False
182
+ self.trial_index = 0
183
+ self.input_amax_trials = torch.zeros(
184
+ self.num_scale_trials, requires_grad=False, dtype=torch.float32
185
+ )
186
+ else:
187
+ raise RuntimeError(
188
+ f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}"
189
+ )
190
+ else:
191
+ raise RuntimeError(
192
+ "Weight tensor not found or has incorrect shape in state dict"
193
+ )
194
+
195
+ def quantize_weight(self):
196
+ if self.weight_initialized:
197
+ return
198
+ amax = torch.max(torch.abs(self.weight.data)).float()
199
+ self.scale = self.amax_to_scale(amax, self.max_value)
200
+ self.float8_data = self.to_fp8_saturated(
201
+ self.weight.data, self.scale, self.max_value
202
+ ).to(self.float8_dtype)
203
+ self.scale_reciprocal = self.scale.reciprocal()
204
+ self.weight.data = torch.zeros(
205
+ 1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
206
+ )
207
+ self.weight_initialized = True
208
+
209
+ def set_weight_tensor(self, tensor: torch.Tensor):
210
+ self.weight.data = tensor
211
+ self.weight_initialized = False
212
+ self.quantize_weight()
213
+
214
+ def amax_to_scale(self, amax, max_val):
215
+ return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)
216
+
217
+ def to_fp8_saturated(self, x, scale, max_val):
218
+ return (x * scale).clamp(-max_val, max_val)
219
+
220
+ def quantize_input(self, x: torch.Tensor):
221
+ if self.input_scale_initialized:
222
+ return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
223
+ self.input_float8_dtype
224
+ )
225
+ elif self.trial_index < self.num_scale_trials:
226
+
227
+ amax = torch.max(torch.abs(x)).float()
228
+
229
+ self.input_amax_trials[self.trial_index] = amax
230
+ self.trial_index += 1
231
+ self.input_scale = self.amax_to_scale(
232
+ self.input_amax_trials[: self.trial_index].max(), self.input_max_value
233
+ )
234
+ self.input_scale_reciprocal = self.input_scale.reciprocal()
235
+ return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
236
+ self.input_float8_dtype
237
+ )
238
+ else:
239
+ self.input_scale = self.amax_to_scale(
240
+ self.input_amax_trials.max(), self.input_max_value
241
+ )
242
+ self.input_scale_reciprocal = self.input_scale.reciprocal()
243
+ self.input_scale_initialized = True
244
+ return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
245
+ self.input_float8_dtype
246
+ )
247
+
248
+ def reset_parameters(self) -> None:
249
+ if self.weight_initialized:
250
+ self.weight = nn.Parameter(
251
+ torch.empty(
252
+ (self.out_features, self.in_features),
253
+ **{
254
+ "dtype": self.weight.dtype,
255
+ "device": self.weight.device,
256
+ },
257
+ )
258
+ )
259
+ self.weight_initialized = False
260
+ self.input_scale_initialized = False
261
+ self.trial_index = 0
262
+ self.input_amax_trials.zero_()
263
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
264
+ if self.bias is not None:
265
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
266
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
267
+ init.uniform_(self.bias, -bound, bound)
268
+ self.quantize_weight()
269
+ self.max_value = torch.finfo(self.float8_dtype).max
270
+ self.input_max_value = torch.finfo(self.input_float8_dtype).max
271
+
272
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
273
+ if self.input_scale_initialized or is_compiling():
274
+ x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
275
+ self.input_float8_dtype
276
+ )
277
+ else:
278
+ x = self.quantize_input(x)
279
+
280
+ prev_dims = x.shape[:-1]
281
+ x = x.view(-1, self.in_features)
282
+
283
+ # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
284
+ out = torch._scaled_mm(
285
+ x,
286
+ self.float8_data.T,
287
+ scale_a=self.input_scale_reciprocal,
288
+ scale_b=self.scale_reciprocal,
289
+ bias=self.bias,
290
+ out_dtype=self.weight.dtype,
291
+ use_fast_accum=True,
292
+ )
293
+ if IS_TORCH_2_4:
294
+ out = out[0]
295
+ out = out.view(*prev_dims, self.out_features)
296
+ return out
297
+
298
+ @classmethod
299
+ def from_linear(
300
+ cls,
301
+ linear: nn.Linear,
302
+ float8_dtype=torch.float8_e4m3fn,
303
+ input_float8_dtype=torch.float8_e5m2,
304
+ ) -> "F8Linear":
305
+ f8_lin = cls(
306
+ in_features=linear.in_features,
307
+ out_features=linear.out_features,
308
+ bias=linear.bias is not None,
309
+ device=linear.weight.device,
310
+ dtype=linear.weight.dtype,
311
+ float8_dtype=float8_dtype,
312
+ float_weight=linear.weight.data,
313
+ float_bias=(linear.bias.data if linear.bias is not None else None),
314
+ input_float8_dtype=input_float8_dtype,
315
+ )
316
+ f8_lin.quantize_weight()
317
+ return f8_lin
318
+
319
+
320
+ @torch.inference_mode()
321
+ def recursive_swap_linears(
322
+ model: nn.Module,
323
+ float8_dtype=torch.float8_e4m3fn,
324
+ input_float8_dtype=torch.float8_e5m2,
325
+ quantize_modulation: bool = True,
326
+ ignore_keys: list[str] = [],
327
+ ) -> None:
328
+ """
329
+ Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
330
+
331
+ This function traverses the model's structure and replaces each nn.Linear
332
+ instance with an F8Linear instance, which uses 8-bit floating point
333
+ quantization for weights. The original linear layer's weights are deleted
334
+ after conversion to save memory.
335
+
336
+ Args:
337
+ model (nn.Module): The PyTorch model to modify.
338
+
339
+ Note:
340
+ This function modifies the model in-place. After calling this function,
341
+ all linear layers in the model will be using 8-bit quantization.
342
+ """
343
+ for name, child in model.named_children():
344
+ if name in ignore_keys:
345
+ continue
346
+ if isinstance(child, Modulation) and not quantize_modulation:
347
+ continue
348
+ if isinstance(child, nn.Linear) and not isinstance(
349
+ child, (F8Linear, CublasLinear)
350
+ ):
351
+
352
+ setattr(
353
+ model,
354
+ name,
355
+ F8Linear.from_linear(
356
+ child,
357
+ float8_dtype=float8_dtype,
358
+ input_float8_dtype=input_float8_dtype,
359
+ ),
360
+ )
361
+ del child
362
+ else:
363
+ recursive_swap_linears(
364
+ child,
365
+ float8_dtype=float8_dtype,
366
+ input_float8_dtype=input_float8_dtype,
367
+ quantize_modulation=quantize_modulation,
368
+ ignore_keys=ignore_keys,
369
+ )
370
+
371
+
372
+ @torch.inference_mode()
373
+ def swap_to_cublaslinear(model: nn.Module):
374
+ if CublasLinear == type(None):
375
+ return
376
+ for name, child in model.named_children():
377
+ if isinstance(child, nn.Linear) and not isinstance(
378
+ child, (F8Linear, CublasLinear)
379
+ ):
380
+ cublas_lin = CublasLinear(
381
+ child.in_features,
382
+ child.out_features,
383
+ bias=child.bias is not None,
384
+ dtype=child.weight.dtype,
385
+ device=child.weight.device,
386
+ )
387
+ cublas_lin.weight.data = child.weight.clone().detach()
388
+ cublas_lin.bias.data = child.bias.clone().detach()
389
+ setattr(model, name, cublas_lin)
390
+ del child
391
+ else:
392
+ swap_to_cublaslinear(child)
393
+
394
+
395
+ @torch.inference_mode()
396
+ def quantize_flow_transformer_and_dispatch_float8(
397
+ flow_model: nn.Module,
398
+ device=torch.device("cuda"),
399
+ float8_dtype=torch.float8_e4m3fn,
400
+ input_float8_dtype=torch.float8_e5m2,
401
+ offload_flow=False,
402
+ swap_linears_with_cublaslinear=True,
403
+ flow_dtype=torch.float16,
404
+ quantize_modulation: bool = True,
405
+ quantize_flow_embedder_layers: bool = True,
406
+ ) -> nn.Module:
407
+ """
408
+ Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
409
+
410
+ Iteratively pushes each module to the device, sets it to eval mode, replaces its linear layers with F8Linear (except for final_layer), and quantizes it.
411
+
412
+ This allows fast dispatch to the GPU and quantization without causing OOM on GPUs with limited memory.
413
+
414
+ After dispatching, if offload_flow is True, offloads the model to cpu.
415
+
416
+ If swap_linears_with_cublaslinear is True and flow_dtype == torch.float16, all remaining linears are swapped with CublasLinear layers for roughly a 2x performance boost on consumer GPUs.
417
+ Otherwise the CublasLinear swap is skipped.
418
+
419
+ For extra precision, you can set quantize_flow_embedder_layers to False;
420
+ this helps maintain the output quality of the flow transformer more than fully quantizing,
421
+ at the expense of ~512MB more VRAM usage.
422
+
423
+ For extra precision, you can set quantize_modulation to False;
424
+ this helps maintain the output quality of the flow transformer more than fully quantizing,
425
+ at the expense of ~2GB more VRAM usage, but it has a much higher impact on image quality than the embedder layers.
426
+ """
427
+ for module in flow_model.double_blocks:
428
+ module.to(device)
429
+ module.eval()
430
+ recursive_swap_linears(
431
+ module,
432
+ float8_dtype=float8_dtype,
433
+ input_float8_dtype=input_float8_dtype,
434
+ quantize_modulation=quantize_modulation,
435
+ )
436
+ torch.cuda.empty_cache()
437
+ for module in flow_model.single_blocks:
438
+ module.to(device)
439
+ module.eval()
440
+ recursive_swap_linears(
441
+ module,
442
+ float8_dtype=float8_dtype,
443
+ input_float8_dtype=input_float8_dtype,
444
+ quantize_modulation=quantize_modulation,
445
+ )
446
+ torch.cuda.empty_cache()
447
+ to_gpu_extras = [
448
+ "vector_in",
449
+ "img_in",
450
+ "txt_in",
451
+ "time_in",
452
+ "guidance_in",
453
+ "final_layer",
454
+ "pe_embedder",
455
+ ]
456
+ for module in to_gpu_extras:
457
+ m_extra = getattr(flow_model, module)
458
+ if m_extra is None:
459
+ continue
460
+ m_extra.to(device)
461
+ m_extra.eval()
462
+ if isinstance(m_extra, nn.Linear) and not isinstance(
463
+ m_extra, (F8Linear, CublasLinear)
464
+ ):
465
+ if quantize_flow_embedder_layers:
466
+ setattr(
467
+ flow_model,
468
+ module,
469
+ F8Linear.from_linear(
470
+ m_extra,
471
+ float8_dtype=float8_dtype,
472
+ input_float8_dtype=input_float8_dtype,
473
+ ),
474
+ )
475
+ del m_extra
476
+ elif module != "final_layer":
477
+ if quantize_flow_embedder_layers:
478
+ recursive_swap_linears(
479
+ m_extra,
480
+ float8_dtype=float8_dtype,
481
+ input_float8_dtype=input_float8_dtype,
482
+ quantize_modulation=quantize_modulation,
483
+ )
484
+ torch.cuda.empty_cache()
485
+ if (
486
+ swap_linears_with_cublaslinear
487
+ and flow_dtype == torch.float16
488
+ and CublasLinear != type(None)
489
+ ):
490
+ swap_to_cublaslinear(flow_model)
491
+ elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
492
+ logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
493
+ if offload_flow:
494
+ flow_model.to("cpu")
495
+ torch.cuda.empty_cache()
496
+ return flow_model
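As a usage sketch for the dispatch helper above (a minimal example; `flow_model` is assumed to be an already-loaded BFL-style Flux flow transformer and is not defined here):

    import torch
    from float8_quantize import quantize_flow_transformer_and_dispatch_float8

    flow_model = quantize_flow_transformer_and_dispatch_float8(
        flow_model,                           # assumed: a loaded Flux flow transformer
        device=torch.device("cuda:0"),
        offload_flow=False,
        swap_linears_with_cublaslinear=True,  # only takes effect when flow_dtype is float16
        flow_dtype=torch.float16,
        quantize_modulation=True,             # set False to trade ~2GB VRAM for quality
        quantize_flow_embedder_layers=False,  # keep embedder layers in higher precision (~512MB more VRAM)
    )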
flux_emphasis.py ADDED
@@ -0,0 +1,447 @@
1
+ from typing import TYPE_CHECKING, Optional
2
+ from pydash import flatten
3
+
4
+ import torch
5
+ from transformers.models.clip.tokenization_clip import CLIPTokenizer
6
+ from einops import repeat
7
+
8
+ if TYPE_CHECKING:
9
+ from flux_pipeline import FluxPipeline
10
+
11
+
12
+ def parse_prompt_attention(text):
13
+ """
14
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
15
+ Accepted tokens are:
16
+ (abc) - increases attention to abc by a multiplier of 1.1
17
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
18
+ [abc] - decreases attention to abc by a multiplier of 1.1
19
+ \\( - literal character '('
20
+ \\[ - literal character '['
21
+ \\) - literal character ')'
22
+ \\] - literal character ']'
23
+ \\ - literal character '\'
24
+ anything else - just text
25
+
26
+ >>> parse_prompt_attention('normal text')
27
+ [['normal text', 1.0]]
28
+ >>> parse_prompt_attention('an (important) word')
29
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
30
+ >>> parse_prompt_attention('(unbalanced')
31
+ [['unbalanced', 1.1]]
32
+ >>> parse_prompt_attention('\\(literal\\]')
33
+ [['(literal]', 1.0]]
34
+ >>> parse_prompt_attention('(unnecessary)(parens)')
35
+ [['unnecessaryparens', 1.1]]
36
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
37
+ [['a ', 1.0],
38
+ ['house', 1.5730000000000004],
39
+ [' ', 1.1],
40
+ ['on', 1.0],
41
+ [' a ', 1.1],
42
+ ['hill', 0.55],
43
+ [', sun, ', 1.1],
44
+ ['sky', 1.4641000000000006],
45
+ ['.', 1.1]]
46
+ """
47
+ import re
48
+
49
+ re_attention = re.compile(
50
+ r"""
51
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
52
+ \)|]|[^\\()\[\]:]+|:
53
+ """,
54
+ re.X,
55
+ )
56
+
57
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
58
+
59
+ res = []
60
+ round_brackets = []
61
+ square_brackets = []
62
+
63
+ round_bracket_multiplier = 1.1
64
+ square_bracket_multiplier = 1 / 1.1
65
+
66
+ def multiply_range(start_position, multiplier):
67
+ for p in range(start_position, len(res)):
68
+ res[p][1] *= multiplier
69
+
70
+ for m in re_attention.finditer(text):
71
+ text = m.group(0)
72
+ weight = m.group(1)
73
+
74
+ if text.startswith("\\"):
75
+ res.append([text[1:], 1.0])
76
+ elif text == "(":
77
+ round_brackets.append(len(res))
78
+ elif text == "[":
79
+ square_brackets.append(len(res))
80
+ elif weight is not None and len(round_brackets) > 0:
81
+ multiply_range(round_brackets.pop(), float(weight))
82
+ elif text == ")" and len(round_brackets) > 0:
83
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
84
+ elif text == "]" and len(square_brackets) > 0:
85
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
86
+ else:
87
+ parts = re.split(re_break, text)
88
+ for i, part in enumerate(parts):
89
+ if i > 0:
90
+ res.append(["BREAK", -1])
91
+ res.append([part, 1.0])
92
+
93
+ for pos in round_brackets:
94
+ multiply_range(pos, round_bracket_multiplier)
95
+
96
+ for pos in square_brackets:
97
+ multiply_range(pos, square_bracket_multiplier)
98
+
99
+ if len(res) == 0:
100
+ res = [["", 1.0]]
101
+
102
+ # merge runs of identical weights
103
+ i = 0
104
+ while i + 1 < len(res):
105
+ if res[i][1] == res[i + 1][1]:
106
+ res[i][0] += res[i + 1][0]
107
+ res.pop(i + 1)
108
+ else:
109
+ i += 1
110
+
111
+ return res
112
+
113
+
114
+ def get_prompts_tokens_with_weights(
115
+ clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
116
+ ):
117
+ """
118
+ Get prompt token ids and weights; this function works for both prompt and negative prompt
119
+
120
+ Args:
121
+ clip_tokenizer (CLIPTokenizer)
122
+ A CLIPTokenizer
123
+ prompt (str)
124
+ A prompt string with weights
125
+
126
+ Returns:
127
+ text_tokens (list)
128
+ A list containing token ids
129
+ text_weights (list)
130
+ A list containing the corresponding weight of each token id
131
+
132
+ Example:
133
+ import torch
134
+ from transformers import CLIPTokenizer
135
+
136
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
137
+ "stablediffusionapi/deliberate-v2"
138
+ , subfolder = "tokenizer"
139
+ , dtype = torch.float16
140
+ )
141
+
142
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
143
+ clip_tokenizer = clip_tokenizer
144
+ ,prompt = "a (red:1.5) cat"*70
145
+ )
146
+ """
147
+ texts_and_weights = parse_prompt_attention(prompt)
148
+ text_tokens, text_weights = [], []
149
+ maxlen = clip_tokenizer.model_max_length
150
+ for word, weight in texts_and_weights:
151
+ # tokenize and discard the starting and the ending token
152
+ token = clip_tokenizer(
153
+ word, truncation=False, padding=False, add_special_tokens=False
154
+ ).input_ids
155
+ # so that we can tokenize a prompt of any length
156
+ # the returned token is a 1d list: [320, 1125, 539, 320]
157
+ if debug:
158
+ print(
159
+ token,
160
+ "|FOR MODEL LEN{}|".format(maxlen),
161
+ clip_tokenizer.decode(
162
+ token, skip_special_tokens=True, clean_up_tokenization_spaces=True
163
+ ),
164
+ )
165
+ # merge the new tokens to the all tokens holder: text_tokens
166
+ text_tokens = [*text_tokens, *token]
167
+
168
+ # each token chunk will come with one weight, like ['red cat', 2.0]
169
+ # need to expand weight for each token.
170
+ chunk_weights = [weight] * len(token)
171
+
172
+ # append the weight back to the weight holder: text_weights
173
+ text_weights = [*text_weights, *chunk_weights]
174
+ return text_tokens, text_weights
175
+
176
+
177
+ def group_tokens_and_weights(
178
+ token_ids: list,
179
+ weights: list,
180
+ pad_last_block=False,
181
+ bos=49406,
182
+ eos=49407,
183
+ max_length=77,
184
+ pad_tokens=True,
185
+ ):
186
+ """
187
+ Produce tokens and weights in groups and pad the missing tokens
188
+
189
+ Args:
190
+ token_ids (list)
191
+ The token ids from tokenizer
192
+ weights (list)
193
+ The weights list from function get_prompts_tokens_with_weights
194
+ pad_last_block (bool)
195
+ Controls whether the last token chunk is padded to the full length with eos
196
+ Returns:
197
+ new_token_ids (2d list)
198
+ new_weights (2d list)
199
+
200
+ Example:
201
+ token_groups,weight_groups = group_tokens_and_weights(
202
+ token_ids = token_id_list
203
+ , weights = token_weight_list
204
+ )
205
+ """
206
+ # TODO: Possibly need to fix this, since this doesn't seem correct.
207
+ # Ignoring for now since I don't know what the consequences might be
208
+ # if changed to <= instead of <.
209
+ max_len = max_length - 2 if max_length < 77 else max_length
210
+ # this will be a 2d list
211
+ new_token_ids = []
212
+ new_weights = []
213
+ while len(token_ids) >= max_len:
214
+ # pop the first max_len tokens and weights
215
+ temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)]
216
+ temp_77_weights = [weights.pop(0) for _ in range(max_len)]
217
+
218
+ # extract token ids and weights
219
+
220
+ if pad_tokens:
221
+ if bos is not None:
222
+ temp_77_token_ids = [bos] + temp_77_token_ids + [eos]
223
+ temp_77_weights = [1.0] + temp_77_weights + [1.0]
224
+ else:
225
+ temp_77_token_ids = temp_77_token_ids + [eos]
226
+ temp_77_weights = temp_77_weights + [1.0]
227
+
228
+ # add 77 token and weights chunk to the holder list
229
+ new_token_ids.append(temp_77_token_ids)
230
+ new_weights.append(temp_77_weights)
231
+
232
+ # handle the leftover tokens, padding if requested
233
+ if len(token_ids) > 0:
234
+ if pad_tokens:
235
+ padding_len = max_len - len(token_ids) if pad_last_block else 0
236
+
237
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
238
+ new_token_ids.append(temp_77_token_ids)
239
+
240
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
241
+ new_weights.append(temp_77_weights)
242
+ else:
243
+ new_token_ids.append(token_ids)
244
+ new_weights.append(weights)
245
+ return new_token_ids, new_weights
246
+
247
+
248
+ def standardize_tensor(
249
+ input_tensor: torch.Tensor, target_mean: float, target_std: float
250
+ ) -> torch.Tensor:
251
+ """
252
+ This function standardizes an input tensor so that it has a specific mean and standard deviation.
253
+
254
+ Parameters:
255
+ input_tensor (torch.Tensor): The tensor to standardize.
256
+ target_mean (float): The target mean for the tensor.
257
+ target_std (float): The target standard deviation for the tensor.
258
+
259
+ Returns:
260
+ torch.Tensor: The standardized tensor.
261
+ """
262
+
263
+ # First, compute the mean and std of the input tensor
264
+ mean = input_tensor.mean()
265
+ std = input_tensor.std()
266
+
267
+ # Then, standardize the tensor to have a mean of 0 and std of 1
268
+ standardized_tensor = (input_tensor - mean) / std
269
+
270
+ # Finally, scale the tensor to the target mean and std
271
+ output_tensor = standardized_tensor * target_std + target_mean
272
+
273
+ return output_tensor
274
+
275
+
276
+ def apply_weights(
277
+ prompt_tokens: torch.Tensor,
278
+ weight_tensor: torch.Tensor,
279
+ token_embedding: torch.Tensor,
280
+ eos_token_id: int,
281
+ pad_last_block: bool = True,
282
+ ) -> torch.FloatTensor:
283
+ mean = token_embedding.mean()
284
+ std = token_embedding.std()
285
+ if pad_last_block:
286
+ pooled_tensor = token_embedding[
287
+ torch.arange(token_embedding.shape[0], device=token_embedding.device),
288
+ (
289
+ prompt_tokens.to(dtype=torch.int, device=token_embedding.device)
290
+ == eos_token_id
291
+ )
292
+ .int()
293
+ .argmax(dim=-1),
294
+ ]
295
+ else:
296
+ pooled_tensor = token_embedding[:, -1]
297
+
298
+ for j in range(len(weight_tensor)):
299
+ if weight_tensor[j] != 1.0:
300
+ token_embedding[:, j] = (
301
+ pooled_tensor
302
+ + (token_embedding[:, j] - pooled_tensor) * weight_tensor[j]
303
+ )
304
+ return standardize_tensor(token_embedding, mean, std)
305
+
306
+
307
+ @torch.inference_mode()
308
+ def get_weighted_text_embeddings_flux(
309
+ pipe: "FluxPipeline",
310
+ prompt: str = "",
311
+ num_images_per_prompt: int = 1,
312
+ device: Optional[torch.device] = None,
313
+ target_device: Optional[torch.device] = torch.device("cuda:0"),
314
+ target_dtype: Optional[torch.dtype] = torch.bfloat16,
315
+ debug: bool = False,
316
+ ):
317
+ """
318
+ This function can process a long prompt with weights, with no length limitation,
319
+ for the Flux pipeline.
320
+
321
+ Args:
322
+ pipe (FluxPipeline)
323
+ prompt (str)
324
+ num_images_per_prompt (int)
325
+ device (torch.device)
326
+ target_device (torch.device)
327
+ target_dtype (torch.dtype)
328
+ debug (bool)
329
+ Returns:
330
+ clip_embeds (torch.Tensor), t5_embeds (torch.Tensor)
331
+ txt_ids (torch.Tensor)
332
+ """
333
+ device = device or pipe._execution_device
334
+
335
+ eos = pipe.clip.tokenizer.eos_token_id
336
+ eos_2 = pipe.t5.tokenizer.eos_token_id
337
+ bos = pipe.clip.tokenizer.bos_token_id
338
+ bos_2 = pipe.t5.tokenizer.bos_token_id
339
+
340
+ clip = pipe.clip.hf_module
341
+ t5 = pipe.t5.hf_module
342
+
343
+ tokenizer_clip = pipe.clip.tokenizer
344
+ tokenizer_t5 = pipe.t5.tokenizer
345
+
346
+ t5_length = 512 if pipe.name == "flux-dev" else 256
347
+ clip_length = 77
348
+
349
+ # tokenizer 1
350
+ prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
351
+ tokenizer_clip, prompt, debug=debug
352
+ )
353
+
354
+ # tokenizer 2
355
+ prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
356
+ tokenizer_t5, prompt, debug=debug
357
+ )
358
+
359
+ prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
360
+ prompt_tokens_clip,
361
+ prompt_weights_clip,
362
+ pad_last_block=True,
363
+ bos=bos,
364
+ eos=eos,
365
+ max_length=clip_length,
366
+ )
367
+ prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights(
368
+ prompt_tokens_t5,
369
+ prompt_weights_t5,
370
+ pad_last_block=True,
371
+ bos=bos_2,
372
+ eos=eos_2,
373
+ max_length=t5_length,
374
+ pad_tokens=False,
375
+ )
376
+ prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped)
377
+ prompt_weights_t5 = flatten(prompt_weights_t5_grouped)
378
+ prompt_tokens_clip = flatten(prompt_tokens_clip_grouped)
379
+ prompt_weights_clip = flatten(prompt_weights_clip_grouped)
380
+
381
+ prompt_tokens_clip = tokenizer_clip.decode(
382
+ prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True
383
+ )
384
+ prompt_tokens_clip = tokenizer_clip(
385
+ prompt_tokens_clip,
386
+ add_special_tokens=True,
387
+ padding="max_length",
388
+ truncation=True,
389
+ max_length=clip_length,
390
+ return_tensors="pt",
391
+ ).input_ids.to(device)
392
+ prompt_tokens_t5 = tokenizer_t5.decode(
393
+ prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True
394
+ )
395
+ prompt_tokens_t5 = tokenizer_t5(
396
+ prompt_tokens_t5,
397
+ add_special_tokens=True,
398
+ padding="max_length",
399
+ truncation=True,
400
+ max_length=t5_length,
401
+ return_tensors="pt",
402
+ ).input_ids.to(device)
403
+
404
+ prompt_weights_t5 = torch.cat(
405
+ [
406
+ torch.tensor(prompt_weights_t5, dtype=torch.float32),
407
+ torch.full(
408
+ (t5_length - torch.tensor(prompt_weights_t5).numel(),),
409
+ 1.0,
410
+ dtype=torch.float32,
411
+ ),
412
+ ],
413
+ dim=0,
414
+ ).to(device)
415
+
416
+ clip_embeds = clip(
417
+ prompt_tokens_clip, output_hidden_states=True, attention_mask=None
418
+ )["pooler_output"]
419
+ if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1:
420
+ clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
421
+
422
+ weight_tensor_t5 = torch.tensor(
423
+ flatten(prompt_weights_t5), dtype=torch.float32, device=device
424
+ )
425
+ t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[
426
+ "last_hidden_state"
427
+ ]
428
+ t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
429
+ if debug:
430
+ print(t5_embeds.shape)
431
+ if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
432
+ t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
433
+ txt_ids = torch.zeros(
434
+ num_images_per_prompt,
435
+ t5_embeds.shape[1],
436
+ 3,
437
+ device=target_device,
438
+ dtype=target_dtype,
439
+ )
440
+ t5_embeds = t5_embeds.to(target_device, dtype=target_dtype)
441
+ clip_embeds = clip_embeds.to(target_device, dtype=target_dtype)
442
+
443
+ return (
444
+ clip_embeds,
445
+ t5_embeds,
446
+ txt_ids,
447
+ )
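To make the emphasis syntax above concrete, here is a small sketch of parsing and tokenizing a weighted prompt; the CLIP tokenizer checkpoint name is only an illustrative assumption:

    from transformers import CLIPTokenizer
    from flux_emphasis import parse_prompt_attention, get_prompts_tokens_with_weights

    print(parse_prompt_attention("a (red:1.5) cat"))
    # [['a ', 1.0], ['red', 1.5], [' cat', 1.0]]

    # Tokenizer checkpoint chosen purely for illustration.
    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    token_ids, weights = get_prompts_tokens_with_weights(tok, "a (red:1.5) cat")
    assert len(token_ids) == len(weights)  # one weight per token id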
flux_pipeline.py ADDED
@@ -0,0 +1,729 @@
1
+ import io
2
+ import math
3
+ import random
4
+ import warnings
5
+ from typing import TYPE_CHECKING, Callable, List, Optional, OrderedDict, Union
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ warnings.filterwarnings("ignore", category=UserWarning)
11
+ warnings.filterwarnings("ignore", category=FutureWarning)
12
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
13
+ import torch
14
+ from einops import rearrange
15
+
16
+ from flux_emphasis import get_weighted_text_embeddings_flux
17
+
18
+ torch.backends.cuda.matmul.allow_tf32 = True
19
+ torch.backends.cudnn.allow_tf32 = True
20
+ torch.backends.cudnn.benchmark = True
21
+ torch.backends.cudnn.benchmark_limit = 20
22
+ torch.set_float32_matmul_precision("high")
23
+ from pybase64 import standard_b64decode
24
+ from torch._dynamo import config
25
+ from torch._inductor import config as ind_config
26
+
27
+ config.cache_size_limit = 10000000000
28
+ ind_config.shape_padding = True
29
+ import platform
30
+
31
+ from loguru import logger
32
+ from torchvision.transforms import functional as TF
33
+ from tqdm import tqdm
34
+
35
+ import lora_loading
36
+ from image_encoder import ImageEncoder
37
+ from util import (
38
+ ModelSpec,
39
+ ModelVersion,
40
+ into_device,
41
+ into_dtype,
42
+ load_config_from_path,
43
+ load_models_from_config,
44
+ )
45
+
46
+ if platform.system() == "Windows":
47
+ MAX_RAND = 2**16 - 1
48
+ else:
49
+ MAX_RAND = 2**32 - 1
50
+
51
+
52
+ if TYPE_CHECKING:
53
+ from modules.autoencoder import AutoEncoder
54
+ from modules.conditioner import HFEmbedder
55
+ from modules.flux_model import Flux
56
+
57
+
58
+ class FluxPipeline:
59
+ """
60
+ FluxPipeline is a class that provides a pipeline for generating images using the Flux model.
61
+ It handles input preparation, timestep generation, noise generation, device management
62
+ and model compilation.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ name: str,
68
+ offload: bool = False,
69
+ clip: "HFEmbedder" = None,
70
+ t5: "HFEmbedder" = None,
71
+ model: "Flux" = None,
72
+ ae: "AutoEncoder" = None,
73
+ dtype: torch.dtype = torch.float16,
74
+ verbose: bool = False,
75
+ flux_device: torch.device | str = "cuda:0",
76
+ ae_device: torch.device | str = "cuda:1",
77
+ clip_device: torch.device | str = "cuda:1",
78
+ t5_device: torch.device | str = "cuda:1",
79
+ config: ModelSpec = None,
80
+ debug: bool = False,
81
+ ):
82
+ """
83
+ Initialize the FluxPipeline class.
84
+
85
+ This class is responsible for preparing input tensors for the Flux model, generating
86
+ timesteps and noise, and handling device management for model offloading.
87
+ """
88
+
89
+ if config is None:
90
+ raise ValueError("ModelSpec config is required!")
91
+
92
+ self.debug = debug
93
+ self.name = name
94
+ self.device_flux = into_device(flux_device)
95
+ self.device_ae = into_device(ae_device)
96
+ self.device_clip = into_device(clip_device)
97
+ self.device_t5 = into_device(t5_device)
98
+ self.dtype = into_dtype(dtype)
99
+ self.offload = offload
100
+ self.clip: "HFEmbedder" = clip
101
+ self.t5: "HFEmbedder" = t5
102
+ self.model: "Flux" = model
103
+ self.ae: "AutoEncoder" = ae
104
+ self.rng = torch.Generator(device="cpu")
105
+ self.img_encoder = ImageEncoder()
106
+ self.verbose = verbose
107
+ self.ae_dtype = torch.bfloat16
108
+ self.config = config
109
+ self.offload_text_encoder = config.offload_text_encoder
110
+ self.offload_vae = config.offload_vae
111
+ self.offload_flow = config.offload_flow
112
+ # If models are not offloaded, move them to the appropriate devices
113
+
114
+ if not self.offload_flow:
115
+ self.model.to(self.device_flux)
116
+ if not self.offload_vae:
117
+ self.ae.to(self.device_ae)
118
+ if not self.offload_text_encoder:
119
+ self.clip.to(self.device_clip)
120
+ self.t5.to(self.device_t5)
121
+
122
+ # compile the model if needed
123
+ if config.compile_blocks or config.compile_extras:
124
+ self.compile()
125
+
126
+ def set_seed(
127
+ self, seed: int | None = None, seed_globally: bool = False
128
+ ) -> torch.Generator:
129
+ if isinstance(seed, (int, float)):
130
+ seed = int(abs(seed)) % MAX_RAND
131
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
132
+ elif isinstance(seed, str):
133
+ try:
134
+ seed = abs(int(seed)) % MAX_RAND
135
+ except Exception as e:
136
+ logger.warning(
137
+ f"Received string representation of seed, but was not able to convert to int: {seed}, using random seed"
138
+ )
139
+ seed = abs(self.rng.seed()) % MAX_RAND
140
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
141
+ else:
142
+ seed = abs(self.rng.seed()) % MAX_RAND
143
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
144
+
145
+ if seed_globally:
146
+ torch.cuda.manual_seed_all(seed)
147
+ np.random.seed(seed)
148
+ random.seed(seed)
149
+ return cuda_generator, seed
150
+
151
+ def load_lora(
152
+ self,
153
+ lora_path: Union[str, OrderedDict[str, torch.Tensor]],
154
+ scale: float,
155
+ name: Optional[str] = None,
156
+ ):
157
+ """
158
+ Loads a LoRA checkpoint into the Flux flow transformer.
159
+
160
+ Currently supports either diffusers-style LoRA checkpoints, whose keys usually start with transformer.[...],
162
+ or LoRAs whose keys start with lora_unet_[...].
162
+
163
+ Args:
164
+ lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights.
165
+ scale (float): Scaling factor for the LoRA weights.
166
+ name (str): Name of the LoRA checkpoint, optionally can be left as None, since it only acts as an identifier.
167
+ """
168
+ self.model.load_lora(path=lora_path, scale=scale, name=name)
169
+
170
+ def unload_lora(self, path_or_identifier: str):
171
+ """
172
+ Unloads the LoRA checkpoint from the Flux flow transformer.
173
+
174
+ Args:
175
+ path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
176
+ """
177
+ self.model.unload_lora(path_or_identifier=path_or_identifier)
178
+
179
+ @torch.inference_mode()
180
+ def compile(self):
181
+ """
182
+ Compiles the model and extras.
183
+
184
+ There are two cases:
185
+
186
+ - A) The checkpoint already has float8-quantized weights and tuned input scales.
187
+ In this case no warmups are run, since the input scales are assumed to already be tuned.
188
+
189
+ - B) The checkpoint has not been quantized, in which case it is quantized
190
+ and the input scales are tuned by running a warmup loop.
191
+ - If the model is flux-schnell, 3 warmup loops are run, since each loop is 4 steps.
192
+ - If the model is flux-dev, a single 12-step warmup loop is run.
193
+
194
+ """
195
+
196
+ # Run warmups if the checkpoint is not prequantized
197
+ if not self.config.prequantized_flow:
198
+ logger.info("Running warmups for compile...")
199
+ warmup_dict = dict(
200
+ prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
201
+ height=768,
202
+ width=768,
203
+ num_steps=12,
204
+ guidance=3.5,
205
+ seed=10,
206
+ )
207
+ if self.config.version == ModelVersion.flux_schnell:
208
+ warmup_dict["num_steps"] = 4
209
+ for _ in range(3):
210
+ self.generate(**warmup_dict)
211
+ else:
212
+ self.generate(**warmup_dict)
213
+
214
+ # Compile the model and extras
215
+ to_gpu_extras = [
216
+ "vector_in",
217
+ "img_in",
218
+ "txt_in",
219
+ "time_in",
220
+ "guidance_in",
221
+ "final_layer",
222
+ "pe_embedder",
223
+ ]
224
+ if self.config.compile_blocks:
225
+ for block in self.model.double_blocks:
226
+ block.compile()
227
+ for block in self.model.single_blocks:
228
+ block.compile()
229
+ if self.config.compile_extras:
230
+ for extra in to_gpu_extras:
231
+ getattr(self.model, extra).compile()
232
+
233
+ @torch.inference_mode()
234
+ def prepare(
235
+ self,
236
+ img: torch.Tensor,
237
+ prompt: str | list[str],
238
+ target_device: torch.device = torch.device("cuda:0"),
239
+ target_dtype: torch.dtype = torch.float16,
240
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
241
+ """
242
+ Prepare input tensors for the Flux model.
243
+
244
+ This function processes the input image and text prompt, converting them into
245
+ the appropriate format and embedding representations required by the model.
246
+
247
+ Args:
248
+ img (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
249
+ prompt (str | list[str]): Text prompt or list of prompts guiding the image generation.
250
+ target_device (torch.device, optional): The target device for the output tensors.
251
+ Defaults to torch.device("cuda:0").
252
+ target_dtype (torch.dtype, optional): The target data type for the output tensors.
253
+ Defaults to torch.float16.
254
+
255
+ Returns:
256
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
257
+ - img: Processed image tensor.
258
+ - img_ids: Image position IDs.
259
+ - vec: Clip text embedding vector.
260
+ - txt: T5 text embedding hidden states.
261
+ - txt_ids: Text position IDs.
262
+
263
+ Note:
264
+ This function handles the necessary device management for text encoder offloading
265
+ if enabled in the configuration.
266
+ """
267
+ bs, c, h, w = img.shape
268
+ if bs == 1 and not isinstance(prompt, str):
269
+ bs = len(prompt)
270
+ img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5)
271
+ img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5])
272
+ assert img.shape == (
273
+ bs,
274
+ (h // 2) * (w // 2),
275
+ c * 2 * 2,
276
+ ), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}"
277
+ if img.shape[0] == 1 and bs > 1:
278
+ img = img[None].repeat_interleave(bs, dim=0)
279
+
280
+ img_ids = torch.zeros(
281
+ h // 2, w // 2, 3, device=target_device, dtype=target_dtype
282
+ )
283
+ img_ids[..., 1] = (
284
+ img_ids[..., 1]
285
+ + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
286
+ )
287
+ img_ids[..., 2] = (
288
+ img_ids[..., 2]
289
+ + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
290
+ )
291
+
292
+ img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
293
+ if self.offload_text_encoder:
294
+ self.clip.to(device=self.device_clip)
295
+ self.t5.to(device=self.device_t5)
296
+
297
+ # get the text embeddings
298
+ vec, txt, txt_ids = get_weighted_text_embeddings_flux(
299
+ self,
300
+ prompt,
301
+ num_images_per_prompt=bs,
302
+ device=self.device_clip,
303
+ target_device=target_device,
304
+ target_dtype=target_dtype,
305
+ debug=self.debug,
306
+ )
307
+ # offload text encoder to cpu if needed
308
+ if self.offload_text_encoder:
309
+ self.clip.to("cpu")
310
+ self.t5.to("cpu")
311
+ torch.cuda.empty_cache()
312
+ return img, img_ids, vec, txt, txt_ids
313
+
314
+ @torch.inference_mode()
315
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
316
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
317
+
318
+ def get_lin_function(
319
+ self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
320
+ ) -> Callable[[float], float]:
321
+ m = (y2 - y1) / (x2 - x1)
322
+ b = y1 - m * x1
323
+ return lambda x: m * x + b
324
+
325
+ @torch.inference_mode()
326
+ def get_schedule(
327
+ self,
328
+ num_steps: int,
329
+ image_seq_len: int,
330
+ base_shift: float = 0.5,
331
+ max_shift: float = 1.15,
332
+ shift: bool = True,
333
+ ) -> list[float]:
334
+ """Generates a schedule of timesteps for the given number of steps and image sequence length."""
335
+ # extra step for zero
336
+ timesteps = torch.linspace(1, 0, num_steps + 1)
337
+
338
+ # shifting the schedule to favor high timesteps for higher signal images
339
+ if shift:
340
+ # estimate mu via linear interpolation between two points
341
+ mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
342
+ timesteps = self.time_shift(mu, 1.0, timesteps)
343
+
344
+ return timesteps.tolist()
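For reference, the schedule shift above follows t' = exp(mu) / (exp(mu) + (1/t - 1) ** sigma) with sigma = 1.0, where mu is a linear function of the image sequence length; a small standalone sketch (the 1024x768 resolution is just an example):

    import math

    def lin(x, x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
        m = (y2 - y1) / (x2 - x1)
        return m * x + (y1 - m * x1)

    def shift(t, mu, sigma=1.0):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    seq_len = (1024 // 16) * (768 // 16)  # latent sequence length for a 1024x768 image
    mu = lin(seq_len)
    print(round(shift(0.5, mu), 3))  # ~0.726: more of the schedule is spent at high noise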
345
+
346
+ @torch.inference_mode()
347
+ def get_noise(
348
+ self,
349
+ num_samples: int,
350
+ height: int,
351
+ width: int,
352
+ generator: torch.Generator,
353
+ dtype=None,
354
+ device=None,
355
+ ) -> torch.Tensor:
356
+ """Generates a latent noise tensor of the given shape and dtype on the given device."""
357
+ if device is None:
358
+ device = self.device_flux
359
+ if dtype is None:
360
+ dtype = self.dtype
361
+ return torch.randn(
362
+ num_samples,
363
+ 16,
364
+ # allow for packing
365
+ 2 * math.ceil(height / 16),
366
+ 2 * math.ceil(width / 16),
367
+ device=device,
368
+ dtype=dtype,
369
+ generator=generator,
370
+ requires_grad=False,
371
+ )
372
+
373
+ @torch.inference_mode()
374
+ def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO:
375
+ """Converts the image tensor to bytes."""
376
+ # bring into PIL format and save
377
+ num_images = x.shape[0]
378
+ images: List[torch.Tensor] = []
379
+ for i in range(num_images):
380
+ x = (
381
+ x[i]
382
+ .clamp(-1, 1)
383
+ .add(1.0)
384
+ .mul(127.5)
385
+ .clamp(0, 255)
386
+ .contiguous()
387
+ .type(torch.uint8)
388
+ )
389
+ images.append(x)
390
+ if len(images) == 1:
391
+ im = images[0]
392
+ else:
393
+ im = torch.vstack(images)
394
+
395
+ im = self.img_encoder.encode_torch(im, quality=jpeg_quality)
396
+ images.clear()
397
+ return im
398
+
399
+ @torch.inference_mode()
400
+ def load_init_image_if_needed(
401
+ self, init_image: torch.Tensor | str | Image.Image | np.ndarray
402
+ ) -> torch.Tensor:
403
+ """
404
+ Loads the initial image if it is a string, numpy array, or PIL.Image;
405
+ if it is a torch.Tensor, it is expected to already be in the correct format and is returned as-is.
406
+ """
407
+ if isinstance(init_image, str):
408
+ try:
409
+ init_image = Image.open(init_image)
410
+ except Exception as e:
411
+ init_image = Image.open(
412
+ io.BytesIO(standard_b64decode(init_image.split(",")[-1]))
413
+ )
414
+ init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
415
+ elif isinstance(init_image, np.ndarray):
416
+ init_image = torch.from_numpy(init_image).type(torch.uint8)
417
+ elif isinstance(init_image, Image.Image):
418
+ init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
419
+
420
+ return init_image
421
+
422
+ @torch.inference_mode()
423
+ def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
424
+ """Decodes the latent tensor to the pixel space."""
425
+ if self.offload_vae:
426
+ self.ae.to(self.device_ae)
427
+ x = x.to(self.device_ae)
428
+ else:
429
+ x = x.to(self.device_ae)
430
+ x = self.unpack(x.float(), height, width)
431
+ with torch.autocast(
432
+ device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
433
+ ):
434
+ x = self.ae.decode(x)
435
+ if self.offload_vae:
436
+ self.ae.to("cpu")
437
+ torch.cuda.empty_cache()
438
+ return x
439
+
440
+ def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
441
+ return rearrange(
442
+ x,
443
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
444
+ h=math.ceil(height / 16),
445
+ w=math.ceil(width / 16),
446
+ ph=2,
447
+ pw=2,
448
+ )
449
+
450
+ @torch.inference_mode()
451
+ def resize_center_crop(
452
+ self, img: torch.Tensor, height: int, width: int
453
+ ) -> torch.Tensor:
454
+ """Resizes and crops the image to the given height and width."""
455
+ img = TF.resize(img, min(width, height))
456
+ img = TF.center_crop(img, (height, width))
457
+ return img
458
+
459
+ @torch.inference_mode()
460
+ def preprocess_latent(
461
+ self,
462
+ init_image: torch.Tensor | np.ndarray = None,
463
+ height: int = 720,
464
+ width: int = 1024,
465
+ num_steps: int = 20,
466
+ strength: float = 1.0,
467
+ generator: torch.Generator = None,
468
+ num_images: int = 1,
469
+ ) -> tuple[torch.Tensor, List[float]]:
470
+ """
471
+ Preprocesses the latent tensor for the given number of steps and image sequence length.
472
+ Also, if an initial image is provided, it is vae encoded and injected with the appropriate noise
473
+ given the strength and number of steps replacing the latent tensor.
474
+ """
475
+ # prepare input
476
+
477
+ if init_image is not None:
478
+ if isinstance(init_image, np.ndarray):
479
+ init_image = torch.from_numpy(init_image)
480
+
481
+ init_image = (
482
+ init_image.permute(2, 0, 1)
483
+ .contiguous()
484
+ .to(self.device_ae, dtype=self.ae_dtype)
485
+ .div(127.5)
486
+ .sub(1)[None, ...]
487
+ )
488
+ init_image = self.resize_center_crop(init_image, height, width)
489
+ with torch.autocast(
490
+ device_type=self.device_ae.type,
491
+ dtype=torch.bfloat16,
492
+ cache_enabled=False,
493
+ ):
494
+ if self.offload_vae:
495
+ self.ae.to(self.device_ae)
496
+ init_image = (
497
+ self.ae.encode(init_image)
498
+ .to(dtype=self.dtype, device=self.device_flux)
499
+ .repeat(num_images, 1, 1, 1)
500
+ )
501
+ if self.offload_vae:
502
+ self.ae.to("cpu")
503
+ torch.cuda.empty_cache()
504
+
505
+ x = self.get_noise(
506
+ num_images,
507
+ height,
508
+ width,
509
+ device=self.device_flux,
510
+ dtype=self.dtype,
511
+ generator=generator,
512
+ )
513
+ timesteps = self.get_schedule(
514
+ num_steps=num_steps,
515
+ image_seq_len=x.shape[-1] * x.shape[-2] // 4,
516
+ shift=(self.name != "flux-schnell"),
517
+ )
518
+ if init_image is not None:
519
+ t_idx = int((1 - strength) * num_steps)
520
+ t = timesteps[t_idx]
521
+ timesteps = timesteps[t_idx:]
522
+ x = t * x + (1.0 - t) * init_image
523
+ return x, timesteps
524
+
525
+ @torch.inference_mode()
526
+ def generate(
527
+ self,
528
+ prompt: str,
529
+ width: int = 720,
530
+ height: int = 1024,
531
+ num_steps: int = 24,
532
+ guidance: float = 3.5,
533
+ seed: int | None = None,
534
+ init_image: torch.Tensor | str | Image.Image | np.ndarray | None = None,
535
+ strength: float = 1.0,
536
+ silent: bool = False,
537
+ num_images: int = 1,
538
+ return_seed: bool = False,
539
+ jpeg_quality: int = 99,
540
+ ) -> io.BytesIO:
541
+ """
542
+ Generate images based on the given prompt and parameters.
543
+
544
+ Args:
545
+ prompt `(str)`: The text prompt to guide the image generation.
546
+
547
+ width `(int, optional)`: Width of the generated image. Defaults to 720.
548
+
549
+ height `(int, optional)`: Height of the generated image. Defaults to 1024.
550
+
551
+ num_steps `(int, optional)`: Number of denoising steps. Defaults to 24.
552
+
553
+ guidance `(float, optional)`: Guidance scale for text-to-image generation. Defaults to 3.5.
554
+
555
+ seed `(int | None, optional)`: Random seed for reproducibility. If None, a random seed is used. Defaults to None.
556
+
557
+ init_image `(torch.Tensor | str | Image.Image | np.ndarray | None, optional)`: Initial image for image-to-image generation. Defaults to None.
558
+
559
+ -- note: if the image's height/width do not match the height/width of the generated image, the image is resized and center-cropped to match the height/width arguments.
560
+
561
+ -- If a string is provided, it is assumed to be either a path to an image file or a base64 encoded image.
562
+
563
+ -- If a numpy array is provided, it is assumed to be an RGB numpy array of shape (height, width, 3) and dtype uint8.
564
+
565
+ -- If a PIL.Image is provided, it is assumed to be an RGB PIL.Image.
566
+
567
+ -- If a torch.Tensor is provided, it is assumed to be a torch.Tensor of shape (height, width, 3) and dtype uint8 with range [0, 255].
568
+
569
+ strength `(float, optional)`: Strength of the init_image in image-to-image generation. Defaults to 1.0.
570
+
571
+ silent `(bool, optional)`: If True, suppresses progress bar. Defaults to False.
572
+
573
+ num_images `(int, optional)`: Number of images to generate. Defaults to 1.
574
+
575
+ return_seed `(bool, optional)`: If True, returns the seed along with the generated image. Defaults to False.
576
+
577
+ jpeg_quality `(int, optional)`: Quality of the JPEG compression. Defaults to 99.
578
+
579
+ Returns:
580
+ io.BytesIO: Generated image(s) in bytes format.
581
+ int: Seed used for generation (only if return_seed is True).
582
+ """
583
+ num_steps = 4 if self.name == "flux-schnell" else num_steps
584
+
585
+ init_image = self.load_init_image_if_needed(init_image)
586
+
587
+ # allow for packing and conversion to latent space
588
+ height = 16 * (height // 16)
589
+ width = 16 * (width // 16)
590
+
591
+ generator, seed = self.set_seed(seed)
592
+
593
+ if not silent:
594
+ logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
595
+
596
+ # preprocess the latent
597
+ img, timesteps = self.preprocess_latent(
598
+ init_image=init_image,
599
+ height=height,
600
+ width=width,
601
+ num_steps=num_steps,
602
+ strength=strength,
603
+ generator=generator,
604
+ num_images=num_images,
605
+ )
606
+
607
+ # prepare inputs
608
+ img, img_ids, vec, txt, txt_ids = map(
609
+ lambda x: x.contiguous(),
610
+ self.prepare(
611
+ img=img,
612
+ prompt=prompt,
613
+ target_device=self.device_flux,
614
+ target_dtype=self.dtype,
615
+ ),
616
+ )
617
+
618
+ # this is ignored for schnell
619
+ guidance_vec = torch.full(
620
+ (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
621
+ )
622
+ t_vec = None
623
+ # dispatch to gpu if offloaded
624
+ if self.offload_flow:
625
+ self.model.to(self.device_flux)
626
+
627
+ # perform the denoising loop
628
+ for t_curr, t_prev in tqdm(
629
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
630
+ ):
631
+ if t_vec is None:
632
+ t_vec = torch.full(
633
+ (img.shape[0],),
634
+ t_curr,
635
+ dtype=self.dtype,
636
+ device=self.device_flux,
637
+ )
638
+ else:
639
+ t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
640
+
641
+ pred = self.model.forward(
642
+ img=img,
643
+ img_ids=img_ids,
644
+ txt=txt,
645
+ txt_ids=txt_ids,
646
+ y=vec,
647
+ timesteps=t_vec,
648
+ guidance=guidance_vec,
649
+ )
650
+
651
+ img = img + (t_prev - t_curr) * pred
652
+
653
+ # offload the model to cpu if needed
654
+ if self.offload_flow:
655
+ self.model.to("cpu")
656
+ torch.cuda.empty_cache()
657
+
658
+ # decode latents to pixel space
659
+ img = self.vae_decode(img, height, width)
660
+
661
+ if return_seed:
662
+ return self.into_bytes(img, jpeg_quality=jpeg_quality), seed
663
+ return self.into_bytes(img, jpeg_quality=jpeg_quality)
664
+
665
+ @classmethod
666
+ def load_pipeline_from_config_path(
667
+ cls, path: str, flow_model_path: str = None, debug: bool = False, **kwargs
668
+ ) -> "FluxPipeline":
669
+ with torch.inference_mode():
670
+ config = load_config_from_path(path)
671
+ if flow_model_path:
672
+ config.ckpt_path = flow_model_path
673
+ for k, v in kwargs.items():
674
+ if hasattr(config, k):
675
+ logger.info(
676
+ f"Overriding config {k}:{getattr(config, k)} with value {v}"
677
+ )
678
+ setattr(config, k, v)
679
+ return cls.load_pipeline_from_config(config, debug=debug)
680
+
681
+ @classmethod
682
+ def load_pipeline_from_config(
683
+ cls, config: ModelSpec, debug: bool = False
684
+ ) -> "FluxPipeline":
685
+ from float8_quantize import quantize_flow_transformer_and_dispatch_float8
686
+
687
+ with torch.inference_mode():
688
+ if debug:
689
+ logger.info(
690
+ f"Loading as prequantized flow transformer? {config.prequantized_flow}"
691
+ )
692
+
693
+ models = load_models_from_config(config)
694
+ config = models.config
695
+ flux_device = into_device(config.flux_device)
696
+ ae_device = into_device(config.ae_device)
697
+ clip_device = into_device(config.text_enc_device)
698
+ t5_device = into_device(config.text_enc_device)
699
+ flux_dtype = into_dtype(config.flow_dtype)
700
+ flow_model = models.flow
701
+
702
+ if not config.prequantized_flow:
703
+ flow_model = quantize_flow_transformer_and_dispatch_float8(
704
+ flow_model,
705
+ flux_device,
706
+ offload_flow=config.offload_flow,
707
+ swap_linears_with_cublaslinear=flux_dtype == torch.float16,
708
+ flow_dtype=flux_dtype,
709
+ quantize_modulation=config.quantize_modulation,
710
+ quantize_flow_embedder_layers=config.quantize_flow_embedder_layers,
711
+ )
712
+ else:
713
+ flow_model.eval().requires_grad_(False)
714
+
715
+ return cls(
716
+ name=config.version,
717
+ clip=models.clip,
718
+ t5=models.t5,
719
+ model=flow_model,
720
+ ae=models.ae,
721
+ dtype=flux_dtype,
722
+ verbose=False,
723
+ flux_device=flux_device,
724
+ ae_device=ae_device,
725
+ clip_device=clip_device,
726
+ t5_device=t5_device,
727
+ config=config,
728
+ debug=debug,
729
+ )
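A usage sketch for the pipeline as a whole; the config path, prompt, and output filename below are placeholders rather than files shipped with this commit:

    from flux_pipeline import FluxPipeline

    pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev.json")  # placeholder path

    image_bytes, seed = pipe.generate(
        prompt="a (vibrant:1.3) watercolor painting of a lighthouse at dusk",
        width=1024,
        height=768,
        num_steps=24,
        guidance=3.5,
        return_seed=True,
    )
    with open("lighthouse.jpg", "wb") as f:
        f.write(image_bytes.getbuffer())
    print(f"seed used: {seed}")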
image_encoder.py ADDED
@@ -0,0 +1,35 @@
1
+ import io
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ class ImageEncoder:
8
+
9
+ @torch.inference_mode()
10
+ def encode_torch(self, img: torch.Tensor, quality=95):
11
+ if img.ndim == 2:
12
+ img = (
13
+ img[None]
14
+ .repeat_interleave(3, dim=0)
15
+ .permute(1, 2, 0)
16
+ .contiguous()
17
+ .clamp(0, 255)
18
+ .type(torch.uint8)
19
+ )
20
+ elif img.ndim == 3:
21
+ if img.shape[0] == 3:
22
+ img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8)
23
+ elif img.shape[2] == 3:
24
+ img = img.contiguous().clamp(0, 255).type(torch.uint8)
25
+ else:
26
+ raise ValueError(f"Unsupported image shape: {img.shape}")
27
+ else:
28
+ raise ValueError(f"Unsupported image num dims: {img.ndim}")
29
+
30
+ img = img.cpu().numpy().astype(np.uint8)
31
+ im = Image.fromarray(img)
32
+ iob = io.BytesIO()
33
+ im.save(iob, format="JPEG", quality=quality)
34
+ iob.seek(0)
35
+ return iob
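A minimal sketch of using the encoder above in isolation; the random tensor stands in for a decoded image:

    import torch
    from image_encoder import ImageEncoder

    encoder = ImageEncoder()
    # Fake uint8 HWC image; in the pipeline this comes from the decoded VAE output.
    img = (torch.rand(512, 512, 3) * 255).to(torch.uint8)
    jpeg_io = encoder.encode_torch(img, quality=95)
    with open("preview.jpg", "wb") as f:
        f.write(jpeg_io.getbuffer())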
lora_loading.py ADDED
@@ -0,0 +1,753 @@
1
+ import re
2
+ from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
3
+ import torch
4
+ from loguru import logger
5
+ from safetensors.torch import load_file
6
+ from tqdm import tqdm
7
+ from torch import nn
8
+
9
+ try:
10
+ from cublas_ops import CublasLinear
11
+ except Exception as e:
12
+ CublasLinear = type(None)
13
+ from float8_quantize import F8Linear
14
+ from modules.flux_model import Flux
15
+
16
+ path_regex = re.compile(r"/|\\")
17
+
18
+ StateDict: TypeAlias = OrderedDict[str, torch.Tensor]
19
+
20
+
21
+ class LoraWeights:
22
+ def __init__(
23
+ self,
24
+ weights: StateDict,
25
+ path: str,
26
+ name: str = None,
27
+ scale: float = 1.0,
28
+ ) -> None:
29
+ self.path = path
30
+ self.weights = weights
31
+ self.name = name if name else path_regex.split(path)[-1]
32
+ self.scale = scale
33
+
34
+
35
+ def swap_scale_shift(weight):
36
+ scale, shift = weight.chunk(2, dim=0)
37
+ new_weight = torch.cat([shift, scale], dim=0)
38
+ return new_weight
39
+
40
+
41
+ def check_if_lora_exists(state_dict, lora_name):
42
+ subkey = lora_name.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0]
43
+ for key in state_dict.keys():
44
+ if subkey in key:
45
+ return subkey
46
+ return False
47
+
48
+
49
+ def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name):
50
+ if (original_stubkey := check_if_lora_exists(state_dict, lora_name)) != False:
51
+ weights_to_pop = [k for k in state_dict.keys() if original_stubkey in k]
52
+ for key in weights_to_pop:
53
+ key_replacement = key.replace(
54
+ original_stubkey, flux_layer_name.replace(".weight", "")
55
+ )
56
+ new_state_dict[key_replacement] = state_dict.pop(key)
57
+ return new_state_dict, state_dict
58
+ else:
59
+ return new_state_dict, state_dict
60
+
61
+
62
+ def convert_diffusers_to_flux_transformer_checkpoint(
63
+ diffusers_state_dict,
64
+ num_layers,
65
+ num_single_layers,
66
+ has_guidance=True,
67
+ prefix="",
68
+ ):
69
+ original_state_dict = {}
70
+
71
+ # time_text_embed.timestep_embedder -> time_in
72
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
73
+ original_state_dict,
74
+ diffusers_state_dict,
75
+ f"{prefix}time_text_embed.timestep_embedder.linear_1.weight",
76
+ "time_in.in_layer.weight",
77
+ )
78
+ # time_text_embed.text_embedder -> vector_in
79
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
80
+ original_state_dict,
81
+ diffusers_state_dict,
82
+ f"{prefix}time_text_embed.text_embedder.linear_1.weight",
83
+ "vector_in.in_layer.weight",
84
+ )
85
+
86
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
87
+ original_state_dict,
88
+ diffusers_state_dict,
89
+ f"{prefix}time_text_embed.text_embedder.linear_2.weight",
90
+ "vector_in.out_layer.weight",
91
+ )
92
+
93
+ if has_guidance:
94
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
95
+ original_state_dict,
96
+ diffusers_state_dict,
97
+ f"{prefix}time_text_embed.guidance_embedder.linear_1.weight",
98
+ "guidance_in.in_layer.weight",
99
+ )
100
+
101
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
102
+ original_state_dict,
103
+ diffusers_state_dict,
104
+ f"{prefix}time_text_embed.guidance_embedder.linear_2.weight",
105
+ "guidance_in.out_layer.weight",
106
+ )
107
+
108
+ # context_embedder -> txt_in
109
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
110
+ original_state_dict,
111
+ diffusers_state_dict,
112
+ f"{prefix}context_embedder.weight",
113
+ "txt_in.weight",
114
+ )
115
+
116
+ # x_embedder -> img_in
117
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
118
+ original_state_dict,
119
+ diffusers_state_dict,
120
+ f"{prefix}x_embedder.weight",
121
+ "img_in.weight",
122
+ )
123
+ # double transformer blocks
124
+ for i in range(num_layers):
125
+ block_prefix = f"transformer_blocks.{i}."
126
+ # norms
127
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
128
+ original_state_dict,
129
+ diffusers_state_dict,
130
+ f"{prefix}{block_prefix}norm1.linear.weight",
131
+ f"double_blocks.{i}.img_mod.lin.weight",
132
+ )
133
+
134
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
135
+ original_state_dict,
136
+ diffusers_state_dict,
137
+ f"{prefix}{block_prefix}norm1_context.linear.weight",
138
+ f"double_blocks.{i}.txt_mod.lin.weight",
139
+ )
140
+
141
+ # Q, K, V
142
+ temp_dict = {}
143
+
144
+ expected_shape_qkv_a = None
145
+ expected_shape_qkv_b = None
146
+ expected_shape_add_qkv_a = None
147
+ expected_shape_add_qkv_b = None
148
+ dtype = None
149
+ device = None
150
+
151
+ for component in [
152
+ "to_q",
153
+ "to_k",
154
+ "to_v",
155
+ "add_q_proj",
156
+ "add_k_proj",
157
+ "add_v_proj",
158
+ ]:
159
+
160
+ sample_component_A_key = (
161
+ f"{prefix}{block_prefix}attn.{component}.lora_A.weight"
162
+ )
163
+ sample_component_B_key = (
164
+ f"{prefix}{block_prefix}attn.{component}.lora_B.weight"
165
+ )
166
+ if (
167
+ sample_component_A_key in diffusers_state_dict
168
+ and sample_component_B_key in diffusers_state_dict
169
+ ):
170
+ sample_component_A = diffusers_state_dict.pop(sample_component_A_key)
171
+ sample_component_B = diffusers_state_dict.pop(sample_component_B_key)
172
+ temp_dict[f"{component}"] = [sample_component_A, sample_component_B]
173
+ if expected_shape_qkv_a is None and not component.startswith("add_"):
174
+ expected_shape_qkv_a = sample_component_A.shape
175
+ expected_shape_qkv_b = sample_component_B.shape
176
+ dtype = sample_component_A.dtype
177
+ device = sample_component_A.device
178
+ if expected_shape_add_qkv_a is None and component.startswith("add_"):
179
+ expected_shape_add_qkv_a = sample_component_A.shape
180
+ expected_shape_add_qkv_b = sample_component_B.shape
181
+ dtype = sample_component_A.dtype
182
+ device = sample_component_A.device
183
+ else:
184
+ logger.info(
185
+ f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
186
+ )
187
+ temp_dict[f"{component}"] = [None, None]
188
+
189
+ if device is not None:
190
+ if expected_shape_qkv_a is not None:
191
+
192
+ if (sq := temp_dict["to_q"])[0] is not None:
193
+ sample_q_A, sample_q_B = sq
194
+ else:
195
+ sample_q_A, sample_q_B = [
196
+ torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
197
+ torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
198
+ ]
199
+ if (sq := temp_dict["to_k"])[0] is not None:
200
+ sample_k_A, sample_k_B = sq
201
+ else:
202
+ sample_k_A, sample_k_B = [
203
+ torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
204
+ torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
205
+ ]
206
+ if (sq := temp_dict["to_v"])[0] is not None:
207
+ sample_v_A, sample_v_B = sq
208
+ else:
209
+ sample_v_A, sample_v_B = [
210
+ torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
211
+ torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
212
+ ]
213
+ original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
214
+ torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
215
+ )
216
+ original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
217
+ torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
218
+ )
219
+ if expected_shape_add_qkv_a is not None:
220
+
221
+ if (sq := temp_dict["add_q_proj"])[0] is not None:
222
+ context_q_A, context_q_B = sq
223
+ else:
224
+ context_q_A, context_q_B = [
225
+ torch.zeros(
226
+ expected_shape_add_qkv_a, dtype=dtype, device=device
227
+ ),
228
+ torch.zeros(
229
+ expected_shape_add_qkv_b, dtype=dtype, device=device
230
+ ),
231
+ ]
232
+ if (sq := temp_dict["add_k_proj"])[0] is not None:
233
+ context_k_A, context_k_B = sq
234
+ else:
235
+ context_k_A, context_k_B = [
236
+ torch.zeros(
237
+ expected_shape_add_qkv_a, dtype=dtype, device=device
238
+ ),
239
+ torch.zeros(
240
+ expected_shape_add_qkv_b, dtype=dtype, device=device
241
+ ),
242
+ ]
243
+ if (sq := temp_dict["add_v_proj"])[0] is not None:
244
+ context_v_A, context_v_B = sq
245
+ else:
246
+ context_v_A, context_v_B = [
247
+ torch.zeros(
248
+ expected_shape_add_qkv_a, dtype=dtype, device=device
249
+ ),
250
+ torch.zeros(
251
+ expected_shape_add_qkv_b, dtype=dtype, device=device
252
+ ),
253
+ ]
254
+
255
+ original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
256
+ torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
257
+ )
258
+ original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
259
+ torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
260
+ )
261
+
262
+ # qk_norm
263
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
264
+ original_state_dict,
265
+ diffusers_state_dict,
266
+ f"{prefix}{block_prefix}attn.norm_q.weight",
267
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale",
268
+ )
269
+
270
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
271
+ original_state_dict,
272
+ diffusers_state_dict,
273
+ f"{prefix}{block_prefix}attn.norm_k.weight",
274
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale",
275
+ )
276
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
277
+ original_state_dict,
278
+ diffusers_state_dict,
279
+ f"{prefix}{block_prefix}attn.norm_added_q.weight",
280
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale",
281
+ )
282
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
283
+ original_state_dict,
284
+ diffusers_state_dict,
285
+ f"{prefix}{block_prefix}attn.norm_added_k.weight",
286
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale",
287
+ )
288
+
289
+ # ff img_mlp
290
+
291
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
292
+ original_state_dict,
293
+ diffusers_state_dict,
294
+ f"{prefix}{block_prefix}ff.net.0.proj.weight",
295
+ f"double_blocks.{i}.img_mlp.0.weight",
296
+ )
297
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
298
+ original_state_dict,
299
+ diffusers_state_dict,
300
+ f"{prefix}{block_prefix}ff.net.2.weight",
301
+ f"double_blocks.{i}.img_mlp.2.weight",
302
+ )
303
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
304
+ original_state_dict,
305
+ diffusers_state_dict,
306
+ f"{prefix}{block_prefix}ff_context.net.0.proj.weight",
307
+ f"double_blocks.{i}.txt_mlp.0.weight",
308
+ )
309
+
310
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
311
+ original_state_dict,
312
+ diffusers_state_dict,
313
+ f"{prefix}{block_prefix}ff_context.net.2.weight",
314
+ f"double_blocks.{i}.txt_mlp.2.weight",
315
+ )
316
+ # output projections
317
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
318
+ original_state_dict,
319
+ diffusers_state_dict,
320
+ f"{prefix}{block_prefix}attn.to_out.0.weight",
321
+ f"double_blocks.{i}.img_attn.proj.weight",
322
+ )
323
+
324
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
325
+ original_state_dict,
326
+ diffusers_state_dict,
327
+ f"{prefix}{block_prefix}attn.to_add_out.weight",
328
+ f"double_blocks.{i}.txt_attn.proj.weight",
329
+ )
330
+
331
+ # single transformer blocks
332
+ for i in range(num_single_layers):
333
+ block_prefix = f"single_transformer_blocks.{i}."
334
+ # norm.linear -> single_blocks.0.modulation.lin
335
+ key_norm = f"{prefix}{block_prefix}norm.linear.weight"
336
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
337
+ original_state_dict,
338
+ diffusers_state_dict,
339
+ key_norm,
340
+ f"single_blocks.{i}.modulation.lin.weight",
341
+ )
342
+
343
+ has_q, has_k, has_v, has_mlp = False, False, False, False
344
+ shape_qkv_a = None
345
+ shape_qkv_b = None
346
+ # Q, K, V, mlp
347
+ # default to None so missing single-block LoRA tensors are skipped instead of raising KeyError
+ q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight", None)
348
+ q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight", None)
349
+ if q_A is not None and q_B is not None:
350
+ has_q = True
351
+ shape_qkv_a = q_A.shape
352
+ shape_qkv_b = q_B.shape
353
+ k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight", None)
354
+ k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight", None)
355
+ if k_A is not None and k_B is not None:
356
+ has_k = True
357
+ shape_qkv_a = k_A.shape
358
+ shape_qkv_b = k_B.shape
359
+ v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight", None)
360
+ v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight", None)
361
+ if v_A is not None and v_B is not None:
362
+ has_v = True
363
+ shape_qkv_a = v_A.shape
364
+ shape_qkv_b = v_B.shape
365
+ mlp_A = diffusers_state_dict.pop(
366
+ f"{prefix}{block_prefix}proj_mlp.lora_A.weight", None
367
+ )
368
+ mlp_B = diffusers_state_dict.pop(
369
+ f"{prefix}{block_prefix}proj_mlp.lora_B.weight", None
370
+ )
371
+ if mlp_A is not None and mlp_B is not None:
372
+ has_mlp = True
373
+ shape_qkv_a = mlp_A.shape
374
+ shape_qkv_b = mlp_B.shape
375
+ if any([has_q, has_k, has_v, has_mlp]):
376
+ if not has_q:
377
+ q_A, q_B = [
378
+ torch.zeros(shape_qkv_a, dtype=dtype, device=device),
379
+ torch.zeros(shape_qkv_b, dtype=dtype, device=device),
380
+ ]
381
+ if not has_k:
382
+ k_A, k_B = [
383
+ torch.zeros(shape_qkv_a, dtype=dtype, device=device),
384
+ torch.zeros(shape_qkv_b, dtype=dtype, device=device),
385
+ ]
386
+ if not has_v:
387
+ v_A, v_B = [
388
+ torch.zeros(shape_qkv_a, dtype=dtype, device=device),
389
+ torch.zeros(shape_qkv_b, dtype=dtype, device=device),
390
+ ]
391
+ if not has_mlp:
392
+ mlp_A, mlp_B = [
393
+ torch.zeros(shape_qkv_a, dtype=dtype, device=device),
394
+ torch.zeros(shape_qkv_b, dtype=dtype, device=device),
395
+ ]
396
+ original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
397
+ [q_A, k_A, v_A, mlp_A], dim=0
398
+ )
399
+ original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
400
+ [q_B, k_B, v_B, mlp_B], dim=0
401
+ )
402
+
403
+ # output projections
404
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
405
+ original_state_dict,
406
+ diffusers_state_dict,
407
+ f"{prefix}{block_prefix}proj_out.weight",
408
+ f"single_blocks.{i}.linear2.weight",
409
+ )
410
+
411
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
412
+ original_state_dict,
413
+ diffusers_state_dict,
414
+ f"{prefix}proj_out.weight",
415
+ "final_layer.linear.weight",
416
+ )
417
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
418
+ original_state_dict,
419
+ diffusers_state_dict,
420
+ f"{prefix}proj_out.bias",
421
+ "final_layer.linear.bias",
422
+ )
423
+ original_state_dict, diffusers_state_dict = convert_if_lora_exists(
424
+ original_state_dict,
425
+ diffusers_state_dict,
426
+ f"{prefix}norm_out.linear.weight",
427
+ "final_layer.adaLN_modulation.1.weight",
428
+ )
429
+ if len(list(diffusers_state_dict.keys())) > 0:
430
+ logger.warning(f"Unexpected keys: {list(diffusers_state_dict.keys())}")
431
+
432
+ return original_state_dict
433
+
434
+
435
+ def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict:
436
+ """
437
+ Convert the state dict from the original Flux checkpoint format to the new format.
438
+
439
+ Args:
440
+ original_state_dict (Dict[str, torch.Tensor]): The original Flux checkpoint state dict.
441
+
442
+ Returns:
443
+ Dict[str, torch.Tensor]: The converted state dict in the new format.
444
+ """
445
+ sd = {
446
+ k.replace("lora_unet_", "")
447
+ .replace("double_blocks_", "double_blocks.")
448
+ .replace("single_blocks_", "single_blocks.")
449
+ .replace("_img_attn_", ".img_attn.")
450
+ .replace("_txt_attn_", ".txt_attn.")
451
+ .replace("_img_mod_", ".img_mod.")
452
+ .replace("_txt_mod_", ".txt_mod.")
453
+ .replace("_img_mlp_", ".img_mlp.")
454
+ .replace("_txt_mlp_", ".txt_mlp.")
455
+ .replace("_linear1", ".linear1")
456
+ .replace("_linear2", ".linear2")
457
+ .replace("_modulation_", ".modulation.")
458
+ .replace("lora_up", "lora_B")
459
+ .replace("lora_down", "lora_A"): v
460
+ for k, v in original_state_dict.items()
461
+ if "lora" in k
462
+ }
463
+ return sd
464
+
465
+
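+ # Editorial note (not part of the original commit): the chained .replace() calls
+ # above map kohya-style "lora_unet_*" keys onto the original Flux naming, e.g.
+ #   "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"
+ #       -> "double_blocks.0.img_attn.qkv.lora_B.weight"
+ # Keys that do not contain "lora" are filtered out entirely.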
466
+ def get_module_for_key(
467
+ key: str, model: Flux
468
+ ) -> F8Linear | torch.nn.Linear | CublasLinear:
469
+ parts = key.split(".")
470
+ module = model
471
+ for part in parts:
472
+ module = getattr(module, part)
473
+ return module
474
+
475
+
476
+ def get_lora_for_key(
477
+ key: str, lora_weights: dict
478
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]:
479
+ """
480
+ Get LoRA weights for a specific key.
481
+
482
+ Args:
483
+ key (str): The key to look up in the LoRA weights.
484
+ lora_weights (dict): Dictionary containing LoRA weights.
485
+
486
+ Returns:
487
+ Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: A tuple containing lora_A, lora_B, and alpha if found, None otherwise.
488
+ """
489
+ prefix = key.split(".lora")[0]
490
+ lora_A = lora_weights.get(f"{prefix}.lora_A.weight")
491
+ lora_B = lora_weights.get(f"{prefix}.lora_B.weight")
492
+ alpha = lora_weights.get(f"{prefix}.alpha")
493
+
494
+ if lora_A is None or lora_B is None:
495
+ return None
496
+ return lora_A, lora_B, alpha
497
+
498
+
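+ # Usage sketch (editorial; the key below is hypothetical):
+ #   lora = get_lora_for_key("double_blocks.0.img_attn.qkv", lora_weights)
+ #   if lora is not None:
+ #       lora_A, lora_B, alpha = lora
+ # None is returned unless both the lora_A and lora_B tensors are present.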
499
+ def get_module_for_key(
500
+ key: str, model: Flux
501
+ ) -> F8Linear | torch.nn.Linear | CublasLinear:
502
+ parts = key.split(".")
503
+ module = model
504
+ for part in parts:
505
+ module = getattr(module, part)
506
+ return module
507
+
508
+
509
+ def calculate_lora_weight(
510
+ lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
511
+ rank: Optional[int] = None,
512
+ lora_scale: float = 1.0,
513
+ device: Optional[Union[torch.device, int, str]] = None,
514
+ ):
515
+ lora_A, lora_B, alpha = lora_weights
516
+ if device is None:
517
+ device = lora_A.device
518
+
519
+ uneven_rank = lora_B.shape[1] != lora_A.shape[0]
520
+ rank_diff = lora_A.shape[0] / lora_B.shape[1]
521
+
522
+ if rank is None:
523
+ rank = lora_B.shape[1]
524
+ if alpha is None:
525
+ alpha = rank
526
+
527
+ dtype = torch.float32
528
+ w_up = lora_A.to(dtype=dtype, device=device)
529
+ w_down = lora_B.to(dtype=dtype, device=device)
530
+
531
+ if alpha != rank:
532
+ w_up = w_up * alpha / rank
533
+ if uneven_rank:
534
+ # Fuse each lora instead of repeat interleave for each individual lora,
535
+ # seems to fuse more correctly.
536
+ fused_lora = torch.zeros(
537
+ (lora_B.shape[0], lora_A.shape[1]), device=device, dtype=dtype
538
+ )
539
+ w_up = w_up.chunk(int(rank_diff), dim=0)
540
+ for w_up_chunk in w_up:
541
+ fused_lora = fused_lora + (lora_scale * torch.mm(w_down, w_up_chunk))
542
+ else:
543
+ fused_lora = lora_scale * torch.mm(w_down, w_up)
544
+ return fused_lora
545
+
546
+
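+ # Clarifying sketch (editorial, mirrors the code above rather than any external spec):
+ # with lora_A of shape (rank, in_features) and lora_B of shape (out_features, rank),
+ # the even-rank path computes
+ #     delta_W = lora_scale * (alpha / rank) * (lora_B @ lora_A)
+ # which matches the (out_features, in_features) base weight; the uneven-rank branch
+ # chunks lora_A and accumulates one matmul per chunk instead.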
547
+ @torch.inference_mode()
548
+ def unfuse_lora_weight_from_module(
549
+ fused_weight: torch.Tensor,
550
+ lora_weights: dict,
551
+ rank: Optional[int] = None,
552
+ lora_scale: float = 1.0,
553
+ ):
554
+ w_dtype = fused_weight.dtype
555
+ dtype = torch.float32
556
+ device = fused_weight.device
557
+
558
+ fused_weight = fused_weight.to(dtype=dtype, device=device)
559
+ fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
560
+ module_weight = fused_weight - fused_lora
561
+ return module_weight.to(dtype=w_dtype, device=device)
562
+
563
+
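+ # Editorial note: this is the inverse of apply_lora_weight_to_module below. Fusing
+ # adds the delta from calculate_lora_weight and unfusing subtracts the same delta
+ # for identical (lora_A, lora_B, alpha) and lora_scale, so a fuse/unfuse round trip
+ # recovers the base weight up to floating-point rounding.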
564
+ @torch.inference_mode()
565
+ def apply_lora_weight_to_module(
566
+ module_weight: torch.Tensor,
567
+ lora_weights: dict,
568
+ rank: Optional[int] = None,
569
+ lora_scale: float = 1.0,
570
+ ):
571
+ w_dtype = module_weight.dtype
572
+ dtype = torch.float32
573
+ device = module_weight.device
574
+
575
+ fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
576
+ fused_weight = module_weight.to(dtype=dtype) + fused_lora
577
+ return fused_weight.to(dtype=w_dtype, device=device)
578
+
579
+
580
+ def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
581
+ check_if_starts_with_transformer = [
582
+ k for k in lora_weights.keys() if k.startswith("transformer.")
583
+ ]
584
+ if len(check_if_starts_with_transformer) > 0:
585
+ lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
586
+ lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
587
+ )
588
+ else:
589
+ lora_weights = convert_from_original_flux_checkpoint(lora_weights)
590
+ logger.info("LoRA weights loaded")
591
+ logger.debug("Extracting keys")
592
+ keys_without_ab = list(
593
+ set(
594
+ [
595
+ key.replace(".lora_A.weight", "")
596
+ .replace(".lora_B.weight", "")
597
+ .replace(".lora_A", "")
598
+ .replace(".lora_B", "")
599
+ .replace(".alpha", "")
600
+ for key in lora_weights.keys()
601
+ ]
602
+ )
603
+ )
604
+ logger.debug("Keys extracted")
605
+ return keys_without_ab, lora_weights
606
+
607
+
608
+ def get_lora_weights(lora_path: str | StateDict):
609
+ if isinstance(lora_path, (dict, LoraWeights)):
610
+ return lora_path, True
611
+ else:
612
+ return load_file(lora_path, "cpu"), False
613
+
614
+
615
+ def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
616
+ dtype = linear.weight.dtype
617
+ weight_is_f8 = False
618
+ if isinstance(linear, F8Linear):
619
+ weight_is_f8 = True
620
+ weight = (
621
+ linear.float8_data.clone()
622
+ .detach()
623
+ .float()
624
+ .mul(linear.scale_reciprocal)
625
+ .to(linear.weight.device)
626
+ )
627
+ elif isinstance(linear, torch.nn.Linear):
628
+ weight = linear.weight.clone().detach().float()
629
+ elif isinstance(linear, CublasLinear) and CublasLinear != type(None):
630
+ weight = linear.weight.clone().detach().float()
631
+ return weight, weight_is_f8, dtype
632
+
633
+
634
+ @torch.inference_mode()
635
+ def apply_lora_to_model(
636
+ model: Flux,
637
+ lora_path: str | StateDict,
638
+ lora_scale: float = 1.0,
639
+ return_lora_resolved: bool = False,
640
+ ) -> Flux:
641
+ has_guidance = model.params.guidance_embed
642
+ logger.info(f"Loading LoRA weights for {lora_path}")
643
+ lora_weights, already_loaded = get_lora_weights(lora_path)
644
+
645
+ if not already_loaded:
646
+ keys_without_ab, lora_weights = resolve_lora_state_dict(
647
+ lora_weights, has_guidance
648
+ )
649
+ elif isinstance(lora_weights, LoraWeights):
650
+ b_ = lora_weights
651
+ lora_weights = b_.weights
652
+ keys_without_ab = list(
653
+ set(
654
+ [
655
+ key.replace(".lora_A.weight", "")
656
+ .replace(".lora_B.weight", "")
657
+ .replace(".lora_A", "")
658
+ .replace(".lora_B", "")
659
+ .replace(".alpha", "")
660
+ for key in lora_weights.keys()
661
+ ]
662
+ )
663
+ )
664
+ else:
665
+ lora_weights = lora_weights
666
+ keys_without_ab = list(
667
+ set(
668
+ [
669
+ key.replace(".lora_A.weight", "")
670
+ .replace(".lora_B.weight", "")
671
+ .replace(".lora_A", "")
672
+ .replace(".lora_B", "")
673
+ .replace(".alpha", "")
674
+ for key in lora_weights.keys()
675
+ ]
676
+ )
677
+ )
678
+ for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
679
+ module = get_module_for_key(key, model)
680
+ weight, is_f8, dtype = extract_weight_from_linear(module)
681
+ lora_sd = get_lora_for_key(key, lora_weights)
682
+ if lora_sd is None:
683
+ # Skipping LoRA application for this module
684
+ continue
685
+ weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
686
+ if is_f8:
687
+ module.set_weight_tensor(weight.type(dtype))
688
+ else:
689
+ module.weight.data = weight.type(dtype)
690
+ logger.success("Lora applied")
691
+ if return_lora_resolved:
692
+ return model, lora_weights
693
+ return model
694
+
695
+
696
+ def remove_lora_from_module(
697
+ model: Flux,
698
+ lora_path: str | StateDict,
699
+ lora_scale: float = 1.0,
700
+ ):
701
+ has_guidance = model.params.guidance_embed
702
+ logger.info(f"Loading LoRA weights for {lora_path}")
703
+ lora_weights, already_loaded = get_lora_weights(lora_path)
704
+
705
+ if not already_loaded:
706
+ keys_without_ab, lora_weights = resolve_lora_state_dict(
707
+ lora_weights, has_guidance
708
+ )
709
+ elif isinstance(lora_weights, LoraWeights):
710
+ b_ = lora_weights
711
+ lora_weights = b_.weights
712
+ keys_without_ab = list(
713
+ set(
714
+ [
715
+ key.replace(".lora_A.weight", "")
716
+ .replace(".lora_B.weight", "")
717
+ .replace(".lora_A", "")
718
+ .replace(".lora_B", "")
719
+ .replace(".alpha", "")
720
+ for key in lora_weights.keys()
721
+ ]
722
+ )
723
+ )
724
+ lora_scale = b_.scale
725
+ else:
726
+ lora_weights = lora_weights
727
+ keys_without_ab = list(
728
+ set(
729
+ [
730
+ key.replace(".lora_A.weight", "")
731
+ .replace(".lora_B.weight", "")
732
+ .replace(".lora_A", "")
733
+ .replace(".lora_B", "")
734
+ .replace(".alpha", "")
735
+ for key in lora_weights.keys()
736
+ ]
737
+ )
738
+ )
739
+
740
+ for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
741
+ module = get_module_for_key(key, model)
742
+ weight, is_f8, dtype = extract_weight_from_linear(module)
743
+ lora_sd = get_lora_for_key(key, lora_weights)
744
+ if lora_sd is None:
745
+ # Skipping LoRA application for this module
746
+ continue
747
+ weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
748
+ if is_f8:
749
+ module.set_weight_tensor(weight.type(dtype))
750
+ else:
751
+ module.weight.data = weight.type(dtype)
752
+ logger.success("Lora unfused")
753
+ return model
main.py ADDED
@@ -0,0 +1,199 @@
1
+ import argparse
2
+ import uvicorn
3
+ from api import app
4
+
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser(description="Launch Flux API server")
8
+ parser.add_argument(
9
+ "-c",
10
+ "--config-path",
11
+ type=str,
12
+ help="Path to the configuration file; if not provided, the model is configured from the command line arguments",
13
+ )
14
+ parser.add_argument(
15
+ "-p",
16
+ "--port",
17
+ type=int,
18
+ default=8088,
19
+ help="Port to run the server on",
20
+ )
21
+ parser.add_argument(
22
+ "-H",
23
+ "--host",
24
+ type=str,
25
+ default="0.0.0.0",
26
+ help="Host to run the server on",
27
+ )
28
+ parser.add_argument(
29
+ "-f", "--flow-model-path", type=str, help="Path to the flow model"
30
+ )
31
+ parser.add_argument(
32
+ "-t", "--text-enc-path", type=str, help="Path to the text encoder"
33
+ )
34
+ parser.add_argument(
35
+ "-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
36
+ )
37
+ parser.add_argument(
38
+ "-m",
39
+ "--model-version",
40
+ type=str,
41
+ choices=["flux-dev", "flux-schnell"],
42
+ default="flux-dev",
43
+ help="Choose model version",
44
+ )
45
+ parser.add_argument(
46
+ "-F",
47
+ "--flux-device",
48
+ type=str,
49
+ default="cuda:0",
50
+ help="Device to run the flow model on",
51
+ )
52
+ parser.add_argument(
53
+ "-T",
54
+ "--text-enc-device",
55
+ type=str,
56
+ default="cuda:0",
57
+ help="Device to run the text encoder on",
58
+ )
59
+ parser.add_argument(
60
+ "-A",
61
+ "--autoencoder-device",
62
+ type=str,
63
+ default="cuda:0",
64
+ help="Device to run the autoencoder on",
65
+ )
66
+ parser.add_argument(
67
+ "-q",
68
+ "--num-to-quant",
69
+ type=int,
70
+ default=20,
71
+ help="Number of linear layers in flow transformer (the 'unet') to quantize",
72
+ )
73
+ parser.add_argument(
74
+ "-C",
75
+ "--compile",
76
+ action="store_true",
77
+ default=False,
78
+ help="Compile the flow model with extra optimizations",
79
+ )
80
+ parser.add_argument(
81
+ "-qT",
82
+ "--quant-text-enc",
83
+ type=str,
84
+ default="qfloat8",
85
+ choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"],
86
+ help="Quantize the T5 text encoder to the given dtype; bf16 leaves it unquantized",
87
+ dest="quant_text_enc",
88
+ )
89
+ parser.add_argument(
90
+ "-qA",
91
+ "--quant-ae",
92
+ action="store_true",
93
+ default=False,
94
+ help="Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16",
95
+ dest="quant_ae",
96
+ )
97
+ parser.add_argument(
98
+ "-OF",
99
+ "--offload-flow",
100
+ action="store_true",
101
+ default=False,
102
+ dest="offload_flow",
103
+ help="Offload the flow model to the CPU when not being used to save memory",
104
+ )
105
+ parser.add_argument(
106
+ "-OA",
107
+ "--no-offload-ae",
108
+ action="store_false",
109
+ default=True,
110
+ dest="offload_ae",
111
+ help="Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed",
112
+ )
113
+ parser.add_argument(
114
+ "-OT",
115
+ "--no-offload-text-enc",
116
+ action="store_false",
117
+ default=True,
118
+ dest="offload_text_enc",
119
+ help="Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed",
120
+ )
121
+ parser.add_argument(
122
+ "-PF",
123
+ "--prequantized-flow",
124
+ action="store_true",
125
+ default=False,
126
+ dest="prequantized_flow",
127
+ help="Load the flow model from a prequantized checkpoint "
128
+ + "(requires loading the flow model, running a minimum of 24 steps, "
129
+ + "and then saving the state_dict as a safetensors file), "
130
+ + "which reduces the size of the checkpoint by about 50% & reduces startup time",
131
+ )
132
+ parser.add_argument(
133
+ "-nqfm",
134
+ "--no-quantize-flow-modulation",
135
+ action="store_false",
136
+ default=True,
137
+ dest="quantize_modulation",
138
+ help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
139
+ )
140
+ parser.add_argument(
141
+ "-qfl",
142
+ "--quantize-flow-embedder-layers",
143
+ action="store_true",
144
+ default=False,
145
+ dest="quantize_flow_embedder_layers",
146
+ help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
147
+ )
148
+ return parser.parse_args()
149
+
150
+
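+ # Example invocations (editorial; the paths are placeholders, not from this commit):
+ #   python main.py --config-path configs/config-dev-fp8.json --port 8088
+ #   python main.py -m flux-schnell -f /models/flux1-schnell.safetensors \
+ #       -a /models/ae.safetensors -t /models/text_encoder -OF
+ # The first form defers everything to a config file; the second builds the config
+ # from the CLI flags defined above.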
151
+ def main():
152
+ args = parse_args()
153
+
154
+ # lazy imports so the CLI returns quickly instead of waiting for torch to load its modules
155
+ from flux_pipeline import FluxPipeline
156
+ from util import load_config, ModelVersion
157
+
158
+ if args.config_path:
159
+ app.state.model = FluxPipeline.load_pipeline_from_config_path(
160
+ args.config_path, flow_model_path=args.flow_model_path
161
+ )
162
+ else:
163
+ model_version = (
164
+ ModelVersion.flux_dev
165
+ if args.model_version == "flux-dev"
166
+ else ModelVersion.flux_schnell
167
+ )
168
+ config = load_config(
169
+ model_version,
170
+ flux_path=args.flow_model_path,
171
+ flux_device=args.flux_device,
172
+ ae_path=args.autoencoder_path,
173
+ ae_device=args.autoencoder_device,
174
+ text_enc_path=args.text_enc_path,
175
+ text_enc_device=args.text_enc_device,
176
+ flow_dtype="float16",
177
+ text_enc_dtype="bfloat16",
178
+ ae_dtype="bfloat16",
179
+ num_to_quant=args.num_to_quant,
180
+ compile_extras=args.compile,
181
+ compile_blocks=args.compile,
182
+ quant_text_enc=(
183
+ None if args.quant_text_enc == "bf16" else args.quant_text_enc
184
+ ),
185
+ quant_ae=args.quant_ae,
186
+ offload_flow=args.offload_flow,
187
+ offload_ae=args.offload_ae,
188
+ offload_text_enc=args.offload_text_enc,
189
+ prequantized_flow=args.prequantized_flow,
190
+ quantize_modulation=args.quantize_modulation,
191
+ quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
192
+ )
193
+ app.state.model = FluxPipeline.load_pipeline_from_config(config)
194
+
195
+ uvicorn.run(app, host=args.host, port=args.port)
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main()
modules/autoencoder.py ADDED
@@ -0,0 +1,336 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor, nn
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class AutoEncoderParams(BaseModel):
8
+ resolution: int
9
+ in_channels: int
10
+ ch: int
11
+ out_ch: int
12
+ ch_mult: list[int]
13
+ num_res_blocks: int
14
+ z_channels: int
15
+ scale_factor: float
16
+ shift_factor: float
17
+
18
+
19
+ def swish(x: Tensor) -> Tensor:
20
+ return x * torch.sigmoid(x)
21
+
22
+
23
+ class AttnBlock(nn.Module):
24
+ def __init__(self, in_channels: int):
25
+ super().__init__()
26
+ self.in_channels = in_channels
27
+
28
+ self.norm = nn.GroupNorm(
29
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
30
+ )
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(
63
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
64
+ )
65
+ self.conv1 = nn.Conv2d(
66
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
67
+ )
68
+ self.norm2 = nn.GroupNorm(
69
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
70
+ )
71
+ self.conv2 = nn.Conv2d(
72
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
73
+ )
74
+ if self.in_channels != self.out_channels:
75
+ self.nin_shortcut = nn.Conv2d(
76
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
77
+ )
78
+
79
+ def forward(self, x):
80
+ h = x
81
+ h = self.norm1(h)
82
+ h = swish(h)
83
+ h = self.conv1(h)
84
+
85
+ h = self.norm2(h)
86
+ h = swish(h)
87
+ h = self.conv2(h)
88
+
89
+ if self.in_channels != self.out_channels:
90
+ x = self.nin_shortcut(x)
91
+
92
+ return x + h
93
+
94
+
95
+ class Downsample(nn.Module):
96
+ def __init__(self, in_channels: int):
97
+ super().__init__()
98
+ # no asymmetric padding in torch conv, must do it ourselves
99
+ self.conv = nn.Conv2d(
100
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
101
+ )
102
+
103
+ def forward(self, x: Tensor):
104
+ pad = (0, 1, 0, 1)
105
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
106
+ x = self.conv(x)
107
+ return x
108
+
109
+
110
+ class Upsample(nn.Module):
111
+ def __init__(self, in_channels: int):
112
+ super().__init__()
113
+ self.conv = nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
115
+ )
116
+
117
+ def forward(self, x: Tensor):
118
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
119
+ x = self.conv(x)
120
+ return x
121
+
122
+
123
+ class Encoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ resolution: int,
127
+ in_channels: int,
128
+ ch: int,
129
+ ch_mult: list[int],
130
+ num_res_blocks: int,
131
+ z_channels: int,
132
+ ):
133
+ super().__init__()
134
+ self.ch = ch
135
+ self.num_resolutions = len(ch_mult)
136
+ self.num_res_blocks = num_res_blocks
137
+ self.resolution = resolution
138
+ self.in_channels = in_channels
139
+ # downsampling
140
+ self.conv_in = nn.Conv2d(
141
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
142
+ )
143
+
144
+ curr_res = resolution
145
+ in_ch_mult = (1,) + tuple(ch_mult)
146
+ self.in_ch_mult = in_ch_mult
147
+ self.down = nn.ModuleList()
148
+ block_in = self.ch
149
+ for i_level in range(self.num_resolutions):
150
+ block = nn.ModuleList()
151
+ attn = nn.ModuleList()
152
+ block_in = ch * in_ch_mult[i_level]
153
+ block_out = ch * ch_mult[i_level]
154
+ for _ in range(self.num_res_blocks):
155
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
156
+ block_in = block_out
157
+ down = nn.Module()
158
+ down.block = block
159
+ down.attn = attn
160
+ if i_level != self.num_resolutions - 1:
161
+ down.downsample = Downsample(block_in)
162
+ curr_res = curr_res // 2
163
+ self.down.append(down)
164
+
165
+ # middle
166
+ self.mid = nn.Module()
167
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
168
+ self.mid.attn_1 = AttnBlock(block_in)
169
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
170
+
171
+ # end
172
+ self.norm_out = nn.GroupNorm(
173
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
174
+ )
175
+ self.conv_out = nn.Conv2d(
176
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
177
+ )
178
+
179
+ def forward(self, x: Tensor) -> Tensor:
180
+ # downsampling
181
+ hs = [self.conv_in(x)]
182
+ for i_level in range(self.num_resolutions):
183
+ for i_block in range(self.num_res_blocks):
184
+ h = self.down[i_level].block[i_block](hs[-1])
185
+ if len(self.down[i_level].attn) > 0:
186
+ h = self.down[i_level].attn[i_block](h)
187
+ hs.append(h)
188
+ if i_level != self.num_resolutions - 1:
189
+ hs.append(self.down[i_level].downsample(hs[-1]))
190
+
191
+ # middle
192
+ h = hs[-1]
193
+ h = self.mid.block_1(h)
194
+ h = self.mid.attn_1(h)
195
+ h = self.mid.block_2(h)
196
+ # end
197
+ h = self.norm_out(h)
198
+ h = swish(h)
199
+ h = self.conv_out(h)
200
+ return h
201
+
202
+
203
+ class Decoder(nn.Module):
204
+ def __init__(
205
+ self,
206
+ ch: int,
207
+ out_ch: int,
208
+ ch_mult: list[int],
209
+ num_res_blocks: int,
210
+ in_channels: int,
211
+ resolution: int,
212
+ z_channels: int,
213
+ ):
214
+ super().__init__()
215
+ self.ch = ch
216
+ self.num_resolutions = len(ch_mult)
217
+ self.num_res_blocks = num_res_blocks
218
+ self.resolution = resolution
219
+ self.in_channels = in_channels
220
+ self.ffactor = 2 ** (self.num_resolutions - 1)
221
+
222
+ # compute in_ch_mult, block_in and curr_res at lowest res
223
+ block_in = ch * ch_mult[self.num_resolutions - 1]
224
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
225
+ self.z_shape = (1, z_channels, curr_res, curr_res)
226
+
227
+ # z to block_in
228
+ self.conv_in = nn.Conv2d(
229
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
230
+ )
231
+
232
+ # middle
233
+ self.mid = nn.Module()
234
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
235
+ self.mid.attn_1 = AttnBlock(block_in)
236
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
237
+
238
+ # upsampling
239
+ self.up = nn.ModuleList()
240
+ for i_level in reversed(range(self.num_resolutions)):
241
+ block = nn.ModuleList()
242
+ attn = nn.ModuleList()
243
+ block_out = ch * ch_mult[i_level]
244
+ for _ in range(self.num_res_blocks + 1):
245
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
246
+ block_in = block_out
247
+ up = nn.Module()
248
+ up.block = block
249
+ up.attn = attn
250
+ if i_level != 0:
251
+ up.upsample = Upsample(block_in)
252
+ curr_res = curr_res * 2
253
+ self.up.insert(0, up) # prepend to get consistent order
254
+
255
+ # end
256
+ self.norm_out = nn.GroupNorm(
257
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
258
+ )
259
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
260
+
261
+ def forward(self, z: Tensor) -> Tensor:
262
+ # z to block_in
263
+ h = self.conv_in(z)
264
+
265
+ # middle
266
+ h = self.mid.block_1(h)
267
+ h = self.mid.attn_1(h)
268
+ h = self.mid.block_2(h)
269
+
270
+ # upsampling
271
+ for i_level in reversed(range(self.num_resolutions)):
272
+ for i_block in range(self.num_res_blocks + 1):
273
+ h = self.up[i_level].block[i_block](h)
274
+ if len(self.up[i_level].attn) > 0:
275
+ h = self.up[i_level].attn[i_block](h)
276
+ if i_level != 0:
277
+ h = self.up[i_level].upsample(h)
278
+
279
+ # end
280
+ h = self.norm_out(h)
281
+ h = swish(h)
282
+ h = self.conv_out(h)
283
+ return h
284
+
285
+
286
+ class DiagonalGaussian(nn.Module):
287
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
288
+ super().__init__()
289
+ self.sample = sample
290
+ self.chunk_dim = chunk_dim
291
+
292
+ def forward(self, z: Tensor) -> Tensor:
293
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
294
+ if self.sample:
295
+ std = torch.exp(0.5 * logvar)
296
+ return mean + std * torch.randn_like(mean)
297
+ else:
298
+ return mean
299
+
300
+
301
+ class AutoEncoder(nn.Module):
302
+ def __init__(self, params: AutoEncoderParams):
303
+ super().__init__()
304
+ self.encoder = Encoder(
305
+ resolution=params.resolution,
306
+ in_channels=params.in_channels,
307
+ ch=params.ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.decoder = Decoder(
313
+ resolution=params.resolution,
314
+ in_channels=params.in_channels,
315
+ ch=params.ch,
316
+ out_ch=params.out_ch,
317
+ ch_mult=params.ch_mult,
318
+ num_res_blocks=params.num_res_blocks,
319
+ z_channels=params.z_channels,
320
+ )
321
+ self.reg = DiagonalGaussian()
322
+
323
+ self.scale_factor = params.scale_factor
324
+ self.shift_factor = params.shift_factor
325
+
326
+ def encode(self, x: Tensor) -> Tensor:
327
+ z = self.reg(self.encoder(x))
328
+ z = self.scale_factor * (z - self.shift_factor)
329
+ return z
330
+
331
+ def decode(self, z: Tensor) -> Tensor:
332
+ z = z / self.scale_factor + self.shift_factor
333
+ return self.decoder(z)
334
+
335
+ def forward(self, x: Tensor) -> Tensor:
336
+ return self.decode(self.encode(x))
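+ # Minimal round-trip sketch (editorial; the parameter values are illustrative
+ # assumptions, not taken from this repository's configs):
+ #   params = AutoEncoderParams(resolution=256, in_channels=3, ch=128, out_ch=3,
+ #                              ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
+ #                              scale_factor=0.3611, shift_factor=0.1159)
+ #   ae = AutoEncoder(params)
+ #   recon = ae(torch.randn(1, 3, 256, 256))  # encode() scales/shifts z, decode() inverts it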
modules/conditioner.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from transformers import (
6
+ CLIPTextModel,
7
+ CLIPTokenizer,
8
+ T5EncoderModel,
9
+ T5Tokenizer,
10
+ __version__,
11
+ )
12
+ from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig
13
+
14
+ CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")
15
+
16
+
17
+ def auto_quantization_config(
18
+ quantization_dtype: str,
19
+ ) -> QuantoConfig | BitsAndBytesConfig:
20
+ if quantization_dtype == "qfloat8":
21
+ return QuantoConfig(weights="float8")
22
+ elif quantization_dtype == "qint4":
23
+ return BitsAndBytesConfig(
24
+ load_in_4bit=True,
25
+ bnb_4bit_compute_dtype=torch.bfloat16,
26
+ bnb_4bit_quant_type="nf4",
27
+ )
28
+ elif quantization_dtype == "qint8":
29
+ return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
30
+ elif quantization_dtype == "qint2":
31
+ return QuantoConfig(weights="int2")
32
+ elif quantization_dtype is None or quantization_dtype == "bfloat16":
33
+ return None
34
+ else:
35
+ raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
36
+
37
+
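+ # Editorial note: "qfloat8" and "qint2" route through transformers' QuantoConfig
+ # (optimum-quanto backend), "qint4"/"qint8" through bitsandbytes, and
+ # "bfloat16"/None skip quantization entirely.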
38
+ class HFEmbedder(nn.Module):
39
+ def __init__(
40
+ self,
41
+ version: str,
42
+ max_length: int,
43
+ device: torch.device | int,
44
+ quantization_dtype: str | None = None,
45
+ offloading_device: torch.device | int | None = torch.device("cpu"),
46
+ is_clip: bool = False,
47
+ **hf_kwargs,
48
+ ):
49
+ super().__init__()
50
+ self.offloading_device = (
51
+ offloading_device
52
+ if isinstance(offloading_device, torch.device)
53
+ else torch.device(offloading_device)
54
+ )
55
+ self.device = (
56
+ device if isinstance(device, torch.device) else torch.device(device)
57
+ )
58
+ self.is_clip = version.startswith("openai") or is_clip
59
+ self.max_length = max_length
60
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
61
+
62
+ auto_quant_config = (
63
+ auto_quantization_config(quantization_dtype)
64
+ if quantization_dtype is not None
65
+ and quantization_dtype != "bfloat16"
66
+ and quantization_dtype != "float16"
67
+ else None
68
+ )
69
+
70
+ # BNB will move to cuda:0 by default if not specified
71
+ if isinstance(auto_quant_config, BitsAndBytesConfig):
72
+ hf_kwargs["device_map"] = {"": self.device.index}
73
+ if auto_quant_config is not None:
74
+ hf_kwargs["quantization_config"] = auto_quant_config
75
+
76
+ if self.is_clip:
77
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
78
+ version, max_length=max_length
79
+ )
80
+
81
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
82
+ version,
83
+ **hf_kwargs,
84
+ )
85
+
86
+ else:
87
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
88
+ version, max_length=max_length
89
+ )
90
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
91
+ version,
92
+ **hf_kwargs,
93
+ )
94
+
95
+ def offload(self):
96
+ self.hf_module.to(device=self.offloading_device)
97
+ torch.cuda.empty_cache()
98
+
99
+ def cuda(self):
100
+ self.hf_module.to(device=self.device)
101
+
102
+ def forward(self, text: list[str]) -> Tensor:
103
+ batch_encoding = self.tokenizer(
104
+ text,
105
+ truncation=True,
106
+ max_length=self.max_length,
107
+ return_length=False,
108
+ return_overflowing_tokens=False,
109
+ padding="max_length",
110
+ return_tensors="pt",
111
+ )
112
+ outputs = self.hf_module(
113
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
114
+ attention_mask=None,
115
+ output_hidden_states=False,
116
+ )
117
+ return outputs[self.output_key]
118
+
119
+
120
+ if __name__ == "__main__":
121
+ model = HFEmbedder(
122
+ "city96/t5-v1_1-xxl-encoder-bf16",
123
+ max_length=512,
124
+ device=0,
125
+ quantization_dtype="qfloat8",
126
+ )
127
+ o = model(["hello"])
128
+ print(o)
modules/flux_model.py ADDED
@@ -0,0 +1,734 @@
1
+ import os
2
+ from collections import namedtuple
3
+ from typing import TYPE_CHECKING, List
4
+
5
+ import torch
6
+ from loguru import logger
7
+
8
+ if TYPE_CHECKING:
9
+ from lora_loading import LoraWeights
10
+ from util import ModelSpec
11
+ DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
12
+ torch.backends.cuda.matmul.allow_tf32 = True
13
+ torch.backends.cudnn.allow_tf32 = True
14
+ torch.backends.cudnn.benchmark = True
15
+ torch.backends.cudnn.benchmark_limit = 20
16
+ torch.set_float32_matmul_precision("high")
17
+ import math
18
+
19
+ from pydantic import BaseModel
20
+ from torch import Tensor, nn
21
+ from torch.nn import functional as F
22
+
23
+
24
+ class FluxParams(BaseModel):
25
+ in_channels: int
26
+ vec_in_dim: int
27
+ context_in_dim: int
28
+ hidden_size: int
29
+ mlp_ratio: float
30
+ num_heads: int
31
+ depth: int
32
+ depth_single_blocks: int
33
+ axes_dim: list[int]
34
+ theta: int
35
+ qkv_bias: bool
36
+ guidance_embed: bool
37
+
38
+
39
+ # attention is always same shape each time it's called per H*W, so compile with fullgraph
40
+ # @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
41
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
42
+ q, k = apply_rope(q, k, pe)
43
+ x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
44
+ x = x.reshape(*x.shape[:-2], -1)
45
+ return x
46
+
47
+
48
+ # @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
49
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
50
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
51
+ omega = 1.0 / (theta**scale)
52
+ out = torch.einsum("...n,d->...nd", pos, omega)
53
+ out = torch.stack(
54
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
55
+ )
56
+ out = out.reshape(*out.shape[:-1], 2, 2)
57
+ return out
58
+
59
+
60
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
61
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
62
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
63
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
64
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
65
+ return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
66
+
67
+
68
+ class EmbedND(nn.Module):
69
+ def __init__(
70
+ self,
71
+ dim: int,
72
+ theta: int,
73
+ axes_dim: list[int],
74
+ dtype: torch.dtype = torch.bfloat16,
75
+ ):
76
+ super().__init__()
77
+ self.dim = dim
78
+ self.theta = theta
79
+ self.axes_dim = axes_dim
80
+ self.dtype = dtype
81
+
82
+ def forward(self, ids: Tensor) -> Tensor:
83
+ n_axes = ids.shape[-1]
84
+ emb = torch.cat(
85
+ [
86
+ rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
87
+ for i in range(n_axes)
88
+ ],
89
+ dim=-3,
90
+ )
91
+
92
+ return emb.unsqueeze(1)
93
+
94
+
95
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
96
+ """
97
+ Create sinusoidal timestep embeddings.
98
+ :param t: a 1-D Tensor of N indices, one per batch element.
99
+ These may be fractional.
100
+ :param dim: the dimension of the output.
101
+ :param max_period: controls the minimum frequency of the embeddings.
102
+ :return: an (N, D) Tensor of positional embeddings.
103
+ """
104
+ t = time_factor * t
105
+ half = dim // 2
106
+ freqs = torch.exp(
107
+ -math.log(max_period)
108
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
109
+ / half
110
+ )
111
+
112
+ args = t[:, None].float() * freqs[None]
113
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
114
+ if dim % 2:
115
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
116
+ return embedding
117
+
118
+
119
+ class MLPEmbedder(nn.Module):
120
+ def __init__(
121
+ self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
122
+ ):
123
+ from float8_quantize import F8Linear
124
+
125
+ super().__init__()
126
+ self.in_layer = (
127
+ nn.Linear(in_dim, hidden_dim, bias=True)
128
+ if not prequantized
129
+ else (
130
+ F8Linear(
131
+ in_features=in_dim,
132
+ out_features=hidden_dim,
133
+ bias=True,
134
+ )
135
+ if quantized
136
+ else nn.Linear(in_dim, hidden_dim, bias=True)
137
+ )
138
+ )
139
+ self.silu = nn.SiLU()
140
+ self.out_layer = (
141
+ nn.Linear(hidden_dim, hidden_dim, bias=True)
142
+ if not prequantized
143
+ else (
144
+ F8Linear(
145
+ in_features=hidden_dim,
146
+ out_features=hidden_dim,
147
+ bias=True,
148
+ )
149
+ if quantized
150
+ else nn.Linear(hidden_dim, hidden_dim, bias=True)
151
+ )
152
+ )
153
+
154
+ def forward(self, x: Tensor) -> Tensor:
155
+ return self.out_layer(self.silu(self.in_layer(x)))
156
+
157
+
158
+ class RMSNorm(torch.nn.Module):
159
+ def __init__(self, dim: int):
160
+ super().__init__()
161
+ self.scale = nn.Parameter(torch.ones(dim))
162
+
163
+ def forward(self, x: Tensor):
164
+ return F.rms_norm(x.float(), self.scale.shape, self.scale, eps=1e-6).to(x)
165
+
166
+
167
+ class QKNorm(torch.nn.Module):
168
+ def __init__(self, dim: int):
169
+ super().__init__()
170
+ self.query_norm = RMSNorm(dim)
171
+ self.key_norm = RMSNorm(dim)
172
+
173
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
174
+ q = self.query_norm(q)
175
+ k = self.key_norm(k)
176
+ return q, k
177
+
178
+
179
+ class SelfAttention(nn.Module):
180
+ def __init__(
181
+ self,
182
+ dim: int,
183
+ num_heads: int = 8,
184
+ qkv_bias: bool = False,
185
+ prequantized: bool = False,
186
+ ):
187
+ super().__init__()
188
+ from float8_quantize import F8Linear
189
+
190
+ self.num_heads = num_heads
191
+ head_dim = dim // num_heads
192
+
193
+ self.qkv = (
194
+ nn.Linear(dim, dim * 3, bias=qkv_bias)
195
+ if not prequantized
196
+ else F8Linear(
197
+ in_features=dim,
198
+ out_features=dim * 3,
199
+ bias=qkv_bias,
200
+ )
201
+ )
202
+ self.norm = QKNorm(head_dim)
203
+ self.proj = (
204
+ nn.Linear(dim, dim)
205
+ if not prequantized
206
+ else F8Linear(
207
+ in_features=dim,
208
+ out_features=dim,
209
+ bias=True,
210
+ )
211
+ )
212
+ self.K = 3
213
+ self.H = self.num_heads
214
+ self.KH = self.K * self.H
215
+
216
+ def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
217
+ B, L, D = x.shape
218
+ q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
219
+ return q, k, v
220
+
221
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
222
+ qkv = self.qkv(x)
223
+ q, k, v = self.rearrange_for_norm(qkv)
224
+ q, k = self.norm(q, k, v)
225
+ x = attention(q, k, v, pe=pe)
226
+ x = self.proj(x)
227
+ return x
228
+
229
+
230
+ ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
231
+
232
+
233
+ class Modulation(nn.Module):
234
+ def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
235
+ super().__init__()
236
+ from float8_quantize import F8Linear
237
+
238
+ self.is_double = double
239
+ self.multiplier = 6 if double else 3
240
+ self.lin = (
241
+ nn.Linear(dim, self.multiplier * dim, bias=True)
242
+ if not quantized_modulation
243
+ else F8Linear(
244
+ in_features=dim,
245
+ out_features=self.multiplier * dim,
246
+ bias=True,
247
+ )
248
+ )
249
+ self.act = nn.SiLU()
250
+
251
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
252
+ out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
253
+
254
+ return (
255
+ ModulationOut(*out[:3]),
256
+ ModulationOut(*out[3:]) if self.is_double else None,
257
+ )
258
+
259
+
260
+ class DoubleStreamBlock(nn.Module):
261
+ def __init__(
262
+ self,
263
+ hidden_size: int,
264
+ num_heads: int,
265
+ mlp_ratio: float,
266
+ qkv_bias: bool = False,
267
+ dtype: torch.dtype = torch.float16,
268
+ quantized_modulation: bool = False,
269
+ prequantized: bool = False,
270
+ ):
271
+ super().__init__()
272
+ from float8_quantize import F8Linear
273
+
274
+ self.dtype = dtype
275
+
276
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
277
+ self.num_heads = num_heads
278
+ self.hidden_size = hidden_size
279
+ self.img_mod = Modulation(
280
+ hidden_size, double=True, quantized_modulation=quantized_modulation
281
+ )
282
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
283
+ self.img_attn = SelfAttention(
284
+ dim=hidden_size,
285
+ num_heads=num_heads,
286
+ qkv_bias=qkv_bias,
287
+ prequantized=prequantized,
288
+ )
289
+
290
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
291
+ self.img_mlp = nn.Sequential(
292
+ (
293
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
294
+ if not prequantized
295
+ else F8Linear(
296
+ in_features=hidden_size,
297
+ out_features=mlp_hidden_dim,
298
+ bias=True,
299
+ )
300
+ ),
301
+ nn.GELU(approximate="tanh"),
302
+ (
303
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
304
+ if not prequantized
305
+ else F8Linear(
306
+ in_features=mlp_hidden_dim,
307
+ out_features=hidden_size,
308
+ bias=True,
309
+ )
310
+ ),
311
+ )
312
+
313
+ self.txt_mod = Modulation(
314
+ hidden_size, double=True, quantized_modulation=quantized_modulation
315
+ )
316
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
317
+ self.txt_attn = SelfAttention(
318
+ dim=hidden_size,
319
+ num_heads=num_heads,
320
+ qkv_bias=qkv_bias,
321
+ prequantized=prequantized,
322
+ )
323
+
324
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
325
+ self.txt_mlp = nn.Sequential(
326
+ (
327
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
328
+ if not prequantized
329
+ else F8Linear(
330
+ in_features=hidden_size,
331
+ out_features=mlp_hidden_dim,
332
+ bias=True,
333
+ )
334
+ ),
335
+ nn.GELU(approximate="tanh"),
336
+ (
337
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
338
+ if not prequantized
339
+ else F8Linear(
340
+ in_features=mlp_hidden_dim,
341
+ out_features=hidden_size,
342
+ bias=True,
343
+ )
344
+ ),
345
+ )
346
+ self.K = 3
347
+ self.H = self.num_heads
348
+ self.KH = self.K * self.H
349
+ self.do_clamp = dtype == torch.float16
350
+
351
+ def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
352
+ B, L, D = x.shape
353
+ q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
354
+ return q, k, v
355
+
356
+ def forward(
357
+ self,
358
+ img: Tensor,
359
+ txt: Tensor,
360
+ vec: Tensor,
361
+ pe: Tensor,
362
+ ) -> tuple[Tensor, Tensor]:
363
+ img_mod1, img_mod2 = self.img_mod(vec)
364
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
365
+
366
+ # prepare image for attention
367
+ img_modulated = self.img_norm1(img)
368
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
369
+ img_qkv = self.img_attn.qkv(img_modulated)
370
+ img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
371
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
372
+
373
+ # prepare txt for attention
374
+ txt_modulated = self.txt_norm1(txt)
375
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
376
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
377
+ txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
378
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
379
+
380
+ q = torch.cat((txt_q, img_q), dim=2)
381
+ k = torch.cat((txt_k, img_k), dim=2)
382
+ v = torch.cat((txt_v, img_v), dim=2)
383
+
384
+ attn = attention(q, k, v, pe=pe)
385
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
386
+ # calculate the img bloks
387
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
388
+ img = img + img_mod2.gate * self.img_mlp(
389
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
390
+ )
391
+
392
+ # calculate the txt bloks
393
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
394
+ txt = txt + txt_mod2.gate * self.txt_mlp(
395
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
396
+ )
397
+ if self.do_clamp:
398
+ img = img.clamp(min=-32000, max=32000)
399
+ txt = txt.clamp(min=-32000, max=32000)
400
+ return img, txt
401
+
402
+
403
+ class SingleStreamBlock(nn.Module):
404
+ """
405
+ A DiT block with parallel linear layers as described in
406
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ hidden_size: int,
412
+ num_heads: int,
413
+ mlp_ratio: float = 4.0,
414
+ qk_scale: float | None = None,
415
+ dtype: torch.dtype = torch.float16,
416
+ quantized_modulation: bool = False,
417
+ prequantized: bool = False,
418
+ ):
419
+ super().__init__()
420
+ from float8_quantize import F8Linear
421
+
422
+ self.dtype = dtype
423
+ self.hidden_dim = hidden_size
424
+ self.num_heads = num_heads
425
+ head_dim = hidden_size // num_heads
426
+ self.scale = qk_scale or head_dim**-0.5
427
+
428
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
429
+ # qkv and mlp_in
430
+ self.linear1 = (
431
+ nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
432
+ if not prequantized
433
+ else F8Linear(
434
+ in_features=hidden_size,
435
+ out_features=hidden_size * 3 + self.mlp_hidden_dim,
436
+ bias=True,
437
+ )
438
+ )
439
+ # proj and mlp_out
440
+ self.linear2 = (
441
+ nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
442
+ if not prequantized
443
+ else F8Linear(
444
+ in_features=hidden_size + self.mlp_hidden_dim,
445
+ out_features=hidden_size,
446
+ bias=True,
447
+ )
448
+ )
449
+
450
+ self.norm = QKNorm(head_dim)
451
+
452
+ self.hidden_size = hidden_size
453
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
454
+
455
+ self.mlp_act = nn.GELU(approximate="tanh")
456
+ self.modulation = Modulation(
457
+ hidden_size,
458
+ double=False,
459
+ quantized_modulation=quantized_modulation and prequantized,
460
+ )
461
+
462
+ self.K = 3
463
+ self.H = self.num_heads
464
+ self.KH = self.K * self.H
465
+ self.do_clamp = dtype == torch.float16
466
+
467
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
468
+ mod = self.modulation(vec)[0]
469
+ pre_norm = self.pre_norm(x)
470
+ x_mod = (1 + mod.scale) * pre_norm + mod.shift
471
+ qkv, mlp = torch.split(
472
+ self.linear1(x_mod),
473
+ [3 * self.hidden_size, self.mlp_hidden_dim],
474
+ dim=-1,
475
+ )
476
+ B, L, D = qkv.shape
477
+ q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
478
+ q, k = self.norm(q, k, v)
479
+ attn = attention(q, k, v, pe=pe)
480
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
481
+ if self.do_clamp:
482
+ out = (x + mod.gate * output).clamp(min=-32000, max=32000)
483
+ else:
484
+ out = x + mod.gate * output
485
+ return out
486
+
487
+
488
+ class LastLayer(nn.Module):
489
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
490
+ super().__init__()
491
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
492
+ self.linear = nn.Linear(
493
+ hidden_size, patch_size * patch_size * out_channels, bias=True
494
+ )
495
+ self.adaLN_modulation = nn.Sequential(
496
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
497
+ )
498
+
499
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
500
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
501
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
502
+ x = self.linear(x)
503
+ return x
504
+
505
+
506
+ class Flux(nn.Module):
507
+ """
508
+ Transformer model for flow matching on sequences.
509
+ """
510
+
511
+ def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
512
+ super().__init__()
513
+
514
+ self.dtype = dtype
515
+ self.params = config.params
516
+ self.in_channels = config.params.in_channels
517
+ self.out_channels = self.in_channels
518
+ self.loras: List[LoraWeights] = []
519
+ prequantized_flow = config.prequantized_flow
520
+ quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
521
+ quantized_modulation = config.quantize_modulation and prequantized_flow
522
+ from float8_quantize import F8Linear
523
+
524
+ if config.params.hidden_size % config.params.num_heads != 0:
525
+ raise ValueError(
526
+ f"Hidden size {config.params.hidden_size} must be divisible by num_heads {config.params.num_heads}"
527
+ )
528
+ pe_dim = config.params.hidden_size // config.params.num_heads
529
+ if sum(config.params.axes_dim) != pe_dim:
530
+ raise ValueError(
531
+ f"Got {config.params.axes_dim} but expected positional dim {pe_dim}"
532
+ )
533
+ self.hidden_size = config.params.hidden_size
534
+ self.num_heads = config.params.num_heads
535
+ self.pe_embedder = EmbedND(
536
+ dim=pe_dim,
537
+ theta=config.params.theta,
538
+ axes_dim=config.params.axes_dim,
539
+ dtype=self.dtype,
540
+ )
541
+ self.img_in = (
542
+ nn.Linear(self.in_channels, self.hidden_size, bias=True)
543
+ if not prequantized_flow
544
+ else (
545
+ F8Linear(
546
+ in_features=self.in_channels,
547
+ out_features=self.hidden_size,
548
+ bias=True,
549
+ )
550
+ if quantized_embedders
551
+ else nn.Linear(self.in_channels, self.hidden_size, bias=True)
552
+ )
553
+ )
554
+ self.time_in = MLPEmbedder(
555
+ in_dim=256,
556
+ hidden_dim=self.hidden_size,
557
+ prequantized=prequantized_flow,
558
+ quantized=quantized_embedders,
559
+ )
560
+ self.vector_in = MLPEmbedder(
561
+ config.params.vec_in_dim,
562
+ self.hidden_size,
563
+ prequantized=prequantized_flow,
564
+ quantized=quantized_embedders,
565
+ )
566
+ self.guidance_in = (
567
+ MLPEmbedder(
568
+ in_dim=256,
569
+ hidden_dim=self.hidden_size,
570
+ prequantized=prequantized_flow,
571
+ quantized=quantized_embedders,
572
+ )
573
+ if config.params.guidance_embed
574
+ else nn.Identity()
575
+ )
576
+ self.txt_in = (
577
+ nn.Linear(config.params.context_in_dim, self.hidden_size)
578
+ if not quantized_embedders
579
+ else (
580
+ F8Linear(
581
+ in_features=config.params.context_in_dim,
582
+ out_features=self.hidden_size,
583
+ bias=True,
584
+ )
585
+ if quantized_embedders
586
+ else nn.Linear(config.params.context_in_dim, self.hidden_size)
587
+ )
588
+ )
589
+
590
+ self.double_blocks = nn.ModuleList(
591
+ [
592
+ DoubleStreamBlock(
593
+ self.hidden_size,
594
+ self.num_heads,
595
+ mlp_ratio=config.params.mlp_ratio,
596
+ qkv_bias=config.params.qkv_bias,
597
+ dtype=self.dtype,
598
+ quantized_modulation=quantized_modulation,
599
+ prequantized=prequantized_flow,
600
+ )
601
+ for _ in range(config.params.depth)
602
+ ]
603
+ )
604
+
605
+ self.single_blocks = nn.ModuleList(
606
+ [
607
+ SingleStreamBlock(
608
+ self.hidden_size,
609
+ self.num_heads,
610
+ mlp_ratio=config.params.mlp_ratio,
611
+ dtype=self.dtype,
612
+ quantized_modulation=quantized_modulation,
613
+ prequantized=prequantized_flow,
614
+ )
615
+ for _ in range(config.params.depth_single_blocks)
616
+ ]
617
+ )
618
+
619
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
620
+
621
+ def get_lora(self, identifier: str):
622
+ for lora in self.loras:
623
+ if lora.path == identifier or lora.name == identifier:
624
+ return lora
625
+
626
+ def has_lora(self, identifier: str):
627
+ for lora in self.loras:
628
+ if lora.path == identifier or lora.name == identifier:
629
+ return True
630
+
631
+ def load_lora(self, path: str, scale: float, name: str | None = None):
632
+ from lora_loading import (
633
+ LoraWeights,
634
+ apply_lora_to_model,
635
+ remove_lora_from_module,
636
+ )
637
+
638
+ if self.has_lora(path):
639
+ lora = self.get_lora(path)
640
+ if lora.scale == scale:
641
+ logger.warning(
642
+ f"Lora {lora.name} already loaded with same scale - ignoring!"
643
+ )
644
+ else:
645
+ remove_lora_from_module(self, lora, lora.scale)
646
+ apply_lora_to_model(self, lora, scale)
647
+ for idx, lora_ in enumerate(self.loras):
648
+ if lora_.path == lora.path:
649
+ self.loras[idx].scale = scale
650
+ break
651
+ else:
652
+ _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
653
+ self.loras.append(LoraWeights(lora, path, name, scale))
654
+
655
+ def unload_lora(self, path_or_identifier: str):
656
+ from lora_loading import remove_lora_from_module
657
+
658
+ removed = False
659
+ for idx, lora_ in enumerate(list(self.loras)):
660
+ if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
661
+ remove_lora_from_module(self, lora_.weights, lora_.scale)
662
+ self.loras.pop(idx)
663
+ removed = True
664
+ break
665
+ if not removed:
666
+ logger.warning(
667
+ f"Couldn't remove lora {path_or_identifier} as it wasn't found fused to the model!"
668
+ )
669
+ else:
670
+ logger.info("Successfully removed lora from module.")
671
+
672
+ def forward(
673
+ self,
674
+ img: Tensor,
675
+ img_ids: Tensor,
676
+ txt: Tensor,
677
+ txt_ids: Tensor,
678
+ timesteps: Tensor,
679
+ y: Tensor,
680
+ guidance: Tensor | None = None,
681
+ ) -> Tensor:
682
+ if img.ndim != 3 or txt.ndim != 3:
683
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
684
+
685
+ # running on sequences img
686
+ img = self.img_in(img)
687
+ vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))
688
+
689
+ if self.params.guidance_embed:
690
+ if guidance is None:
691
+ raise ValueError(
692
+ "Didn't get guidance strength for guidance distilled model."
693
+ )
694
+ vec = vec + self.guidance_in(
695
+ timestep_embedding(guidance, 256).type(self.dtype)
696
+ )
697
+ vec = vec + self.vector_in(y)
698
+
699
+ txt = self.txt_in(txt)
700
+
701
+ ids = torch.cat((txt_ids, img_ids), dim=1)
702
+ pe = self.pe_embedder(ids)
703
+
704
+ # double stream blocks
705
+ for block in self.double_blocks:
706
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
707
+
708
+ img = torch.cat((txt, img), 1)
709
+
710
+ # single stream blocks
711
+ for block in self.single_blocks:
712
+ img = block(img, vec=vec, pe=pe)
713
+
714
+ img = img[:, txt.shape[1] :, ...]
715
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
716
+ return img
717
+
718
+ @classmethod
719
+ def from_pretrained(
720
+ cls: "Flux", path: str, dtype: torch.dtype = torch.float16
721
+ ) -> "Flux":
722
+ from safetensors.torch import load_file
723
+
724
+ from util import load_config_from_path
725
+
726
+ config = load_config_from_path(path)
727
+ with torch.device("meta"):
728
+ klass = cls(config=config, dtype=dtype)
729
+ if not config.prequantized_flow:
730
+ klass.type(dtype)
731
+
732
+ ckpt = load_file(config.ckpt_path, device="cpu")
733
+ klass.load_state_dict(ckpt, assign=True)
734
+ return klass.to("cpu")
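For context, a minimal usage sketch of the from_pretrained and LoRA helpers defined above. The config and LoRA paths are hypothetical placeholders, and the import assumes the repo's modules/ layout.

import torch

from modules.flux_model import Flux

# Build the model from a ModelSpec JSON (hypothetical path); weights land on CPU first.
model = Flux.from_pretrained("configs/config-dev-fp8.json", dtype=torch.float16)
model = model.to("cuda:0")  # move to the GPU before running inference

# Fuse a LoRA at 0.8 strength; calling load_lora again with a new scale re-fuses it.
model.load_lora("loras/my_style.safetensors", scale=0.8, name="my-style")

# ... run the flow model ...

# Unfuse by the name given at load time (or by the original path).
model.unload_lora("my-style")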
photo.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8476d6d93124c5265bb9ff0b393600b7ca26b3a566822c036cd9f59141065a9b
3
+ size 174924704
start.py ADDED
File without changes
util.py ADDED
@@ -0,0 +1,333 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Literal, Optional
4
+
5
+ import torch
6
+ from modules.autoencoder import AutoEncoder, AutoEncoderParams
7
+ from modules.conditioner import HFEmbedder
8
+ from modules.flux_model import Flux, FluxParams
9
+ from safetensors.torch import load_file as load_sft
10
+
11
+ try:
12
+ from enum import StrEnum
13
+ except ImportError:
14
+ from enum import Enum
15
+
16
+ class StrEnum(str, Enum):
17
+ pass
18
+
19
+
20
+ from pydantic import BaseModel, ConfigDict
21
+ from loguru import logger
22
+
23
+
24
+ class ModelVersion(StrEnum):
25
+ flux_dev = "flux-dev"
26
+ flux_schnell = "flux-schnell"
27
+
28
+
29
+ class QuantizationDtype(StrEnum):
30
+ qfloat8 = "qfloat8"
31
+ qint2 = "qint2"
32
+ qint4 = "qint4"
33
+ qint8 = "qint8"
34
+ bfloat16 = "bfloat16"
35
+ float16 = "float16"
36
+
37
+
38
+ class ModelSpec(BaseModel):
39
+ version: ModelVersion
40
+ params: FluxParams
41
+ ae_params: AutoEncoderParams
42
+ ckpt_path: str | None
43
+ # Optionally pass in a custom CLIP model
44
+ clip_path: str | None = "openai/clip-vit-large-patch14"
45
+ ae_path: str | None
46
+ repo_id: str | None
47
+ repo_flow: str | None
48
+ repo_ae: str | None
49
+ text_enc_max_length: int = 512
50
+ text_enc_path: str | None
51
+ text_enc_device: str | torch.device | None = "cuda:0"
52
+ ae_device: str | torch.device | None = "cuda:0"
53
+ flux_device: str | torch.device | None = "cuda:0"
54
+ flow_dtype: str = "float16"
55
+ ae_dtype: str = "bfloat16"
56
+ text_enc_dtype: str = "bfloat16"
57
+ # unused / deprecated
58
+ num_to_quant: Optional[int] = 20
59
+ quantize_extras: bool = False
60
+ compile_extras: bool = False
61
+ compile_blocks: bool = False
62
+ flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
63
+ text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
64
+ ae_quantization_dtype: Optional[QuantizationDtype] = None
65
+ clip_quantization_dtype: Optional[QuantizationDtype] = None
66
+ offload_text_encoder: bool = False
67
+ offload_vae: bool = False
68
+ offload_flow: bool = False
69
+ prequantized_flow: bool = False
70
+
71
+ # Improved precision by not quantizing the modulation linear layers
72
+ quantize_modulation: bool = True
73
+ # Improved precision by not quantizing the flow embedder layers
74
+ quantize_flow_embedder_layers: bool = False
75
+
76
+ model_config: ConfigDict = {
77
+ "arbitrary_types_allowed": True,
78
+ "use_enum_values": True,
79
+ }
80
+
81
+
82
+ def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
83
+ flow = load_flow_model(config)
84
+ ae = load_autoencoder(config)
85
+ clip, t5 = load_text_encoders(config)
86
+ return flow, ae, clip, t5
87
+
88
+
89
+ def parse_device(device: str | torch.device | None) -> torch.device:
90
+ if isinstance(device, str):
91
+ return torch.device(device)
92
+ elif isinstance(device, torch.device):
93
+ return device
94
+ else:
95
+ return torch.device("cuda:0")
96
+
97
+
98
+ def into_dtype(dtype: str) -> torch.dtype:
99
+ if isinstance(dtype, torch.dtype):
100
+ return dtype
101
+ if dtype == "float16":
102
+ return torch.float16
103
+ elif dtype == "bfloat16":
104
+ return torch.bfloat16
105
+ elif dtype == "float32":
106
+ return torch.float32
107
+ else:
108
+ raise ValueError(f"Invalid dtype: {dtype}")
109
+
110
+
111
+ def into_device(device: str | torch.device | None) -> torch.device:
112
+ if isinstance(device, str):
113
+ return torch.device(device)
114
+ elif isinstance(device, torch.device):
115
+ return device
116
+ elif isinstance(device, int):
117
+ return torch.device(f"cuda:{device}")
118
+ else:
119
+ return torch.device("cuda:0")
120
+
121
+
122
+ def load_config(
123
+ name: ModelVersion = ModelVersion.flux_dev,
124
+ flux_path: str | None = None,
125
+ ae_path: str | None = None,
126
+ text_enc_path: str | None = None,
127
+ text_enc_device: str | torch.device | None = None,
128
+ ae_device: str | torch.device | None = None,
129
+ flux_device: str | torch.device | None = None,
130
+ flow_dtype: str = "float16",
131
+ ae_dtype: str = "bfloat16",
132
+ text_enc_dtype: str = "bfloat16",
133
+ num_to_quant: Optional[int] = 20,
134
+ compile_extras: bool = False,
135
+ compile_blocks: bool = False,
136
+ offload_text_enc: bool = False,
137
+ offload_ae: bool = False,
138
+ offload_flow: bool = False,
139
+ quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
140
+ quant_ae: bool = False,
141
+ prequantized_flow: bool = False,
142
+ quantize_modulation: bool = True,
143
+ quantize_flow_embedder_layers: bool = False,
144
+ ) -> ModelSpec:
145
+ """
146
+ Load a model configuration using the passed arguments.
147
+ """
148
+ text_enc_device = str(parse_device(text_enc_device))
149
+ ae_device = str(parse_device(ae_device))
150
+ flux_device = str(parse_device(flux_device))
151
+ return ModelSpec(
152
+ version=name,
153
+ repo_id=(
154
+ "black-forest-labs/FLUX.1-dev"
155
+ if name == ModelVersion.flux_dev
156
+ else "black-forest-labs/FLUX.1-schnell"
157
+ ),
158
+ repo_flow=(
159
+ "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft"
160
+ ),
161
+ repo_ae="ae.sft",
162
+ ckpt_path=flux_path,
163
+ params=FluxParams(
164
+ in_channels=64,
165
+ vec_in_dim=768,
166
+ context_in_dim=4096,
167
+ hidden_size=3072,
168
+ mlp_ratio=4.0,
169
+ num_heads=24,
170
+ depth=19,
171
+ depth_single_blocks=38,
172
+ axes_dim=[16, 56, 56],
173
+ theta=10_000,
174
+ qkv_bias=True,
175
+ guidance_embed=name == ModelVersion.flux_dev,
176
+ ),
177
+ ae_path=ae_path,
178
+ ae_params=AutoEncoderParams(
179
+ resolution=256,
180
+ in_channels=3,
181
+ ch=128,
182
+ out_ch=3,
183
+ ch_mult=[1, 2, 4, 4],
184
+ num_res_blocks=2,
185
+ z_channels=16,
186
+ scale_factor=0.3611,
187
+ shift_factor=0.1159,
188
+ ),
189
+ text_enc_path=text_enc_path,
190
+ text_enc_device=text_enc_device,
191
+ ae_device=ae_device,
192
+ flux_device=flux_device,
193
+ flow_dtype=flow_dtype,
194
+ ae_dtype=ae_dtype,
195
+ text_enc_dtype=text_enc_dtype,
196
+ text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
197
+ num_to_quant=num_to_quant,
198
+ compile_extras=compile_extras,
199
+ compile_blocks=compile_blocks,
200
+ offload_flow=offload_flow,
201
+ offload_text_encoder=offload_text_enc,
202
+ offload_vae=offload_ae,
203
+ text_enc_quantization_dtype={
204
+ "float8": QuantizationDtype.qfloat8,
205
+ "qint2": QuantizationDtype.qint2,
206
+ "qint4": QuantizationDtype.qint4,
207
+ "qint8": QuantizationDtype.qint8,
208
+ }.get(quant_text_enc, None),
209
+ ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
210
+ prequantized_flow=prequantized_flow,
211
+ quantize_modulation=quantize_modulation,
212
+ quantize_flow_embedder_layers=quantize_flow_embedder_layers,
213
+ )
214
+
215
+
216
+ def load_config_from_path(path: str) -> ModelSpec:
217
+ path_path = Path(path)
218
+ if not path_path.exists():
219
+ raise ValueError(f"Path {path} does not exist")
220
+ if not path_path.is_file():
221
+ raise ValueError(f"Path {path} is not a file")
222
+ return ModelSpec(**json.loads(path_path.read_text()))
223
+
224
+
225
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
226
+ if len(missing) > 0 and len(unexpected) > 0:
227
+ logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
228
+ logger.warning("\n" + "-" * 79 + "\n")
229
+ logger.warning(
230
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
231
+ )
232
+ elif len(missing) > 0:
233
+ logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
234
+ elif len(unexpected) > 0:
235
+ logger.warning(
236
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
237
+ )
238
+
239
+
240
+ def load_flow_model(config: ModelSpec) -> Flux:
241
+ ckpt_path = config.ckpt_path
242
+ FluxClass = Flux
243
+
244
+ with torch.device("meta"):
245
+ model = FluxClass(config, dtype=into_dtype(config.flow_dtype))
246
+ if not config.prequantized_flow:
247
+ model.type(into_dtype(config.flow_dtype))
248
+
249
+ if ckpt_path is not None:
250
+ # load_sft doesn't support torch.device
251
+ sd = load_sft(ckpt_path, device="cpu")
252
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
253
+ print_load_warning(missing, unexpected)
254
+ if not config.prequantized_flow:
255
+ model.type(into_dtype(config.flow_dtype))
256
+ return model
257
+
258
+
259
+ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
260
+ clip = HFEmbedder(
261
+ config.clip_path,
262
+ max_length=77,
263
+ torch_dtype=into_dtype(config.text_enc_dtype),
264
+ device=into_device(config.text_enc_device).index or 0,
265
+ is_clip=True,
266
+ quantization_dtype=config.clip_quantization_dtype,
267
+ )
268
+ t5 = HFEmbedder(
269
+ config.text_enc_path,
270
+ max_length=config.text_enc_max_length,
271
+ torch_dtype=into_dtype(config.text_enc_dtype),
272
+ device=into_device(config.text_enc_device).index or 0,
273
+ quantization_dtype=config.text_enc_quantization_dtype,
274
+ )
275
+ return clip, t5
276
+
277
+
278
+ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
279
+ ckpt_path = config.ae_path
280
+ with torch.device("meta" if ckpt_path is not None else config.ae_device):
281
+ ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))
282
+
283
+ if ckpt_path is not None:
284
+ sd = load_sft(ckpt_path, device=str(config.ae_device))
285
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
286
+ print_load_warning(missing, unexpected)
287
+ ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
288
+ if config.ae_quantization_dtype is not None:
289
+ from float8_quantize import recursive_swap_linears
290
+
291
+ recursive_swap_linears(ae)
292
+ if config.offload_vae:
293
+ ae.to("cpu")
294
+ torch.cuda.empty_cache()
295
+ return ae
296
+
297
+
298
+ class LoadedModels(BaseModel):
299
+ flow: Flux
300
+ ae: AutoEncoder
301
+ clip: HFEmbedder
302
+ t5: HFEmbedder
303
+ config: ModelSpec
304
+
305
+ model_config = {
306
+ "arbitrary_types_allowed": True,
307
+ "use_enum_values": True,
308
+ }
309
+
310
+
311
+ def load_models_from_config_path(
312
+ path: str,
313
+ ) -> LoadedModels:
314
+ config = load_config_from_path(path)
315
+ clip, t5 = load_text_encoders(config)
316
+ return LoadedModels(
317
+ flow=load_flow_model(config),
318
+ ae=load_autoencoder(config),
319
+ clip=clip,
320
+ t5=t5,
321
+ config=config,
322
+ )
323
+
324
+
325
+ def load_models_from_config(config: ModelSpec) -> LoadedModels:
326
+ clip, t5 = load_text_encoders(config)
327
+ return LoadedModels(
328
+ flow=load_flow_model(config),
329
+ ae=load_autoencoder(config),
330
+ clip=clip,
331
+ t5=t5,
332
+ config=config,
333
+ )
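For context, a minimal sketch of how the helpers in util.py fit together; every checkpoint path below is a hypothetical placeholder.

from util import ModelVersion, load_config, load_models_from_config

# Assemble a ModelSpec in code rather than reading one from JSON.
config = load_config(
    name=ModelVersion.flux_dev,
    flux_path="/models/flux1-dev.sft",      # flow transformer checkpoint (hypothetical)
    ae_path="/models/ae.sft",               # autoencoder checkpoint (hypothetical)
    text_enc_path="/models/t5-xxl-encoder", # T5 weights: local path or HF repo id (hypothetical)
    quant_text_enc="qint4",                 # quantize the T5 encoder to int4
    offload_flow=False,
)

# Load the flow model, autoencoder, CLIP, and T5 onto their configured devices.
loaded = load_models_from_config(config)
flow, ae, clip, t5 = loaded.flow, loaded.ae, loaded.clip, loaded.t5

# Equivalently, load_models_from_config_path("config-dev-fp8.json") builds everything
# from a ModelSpec JSON on disk (path hypothetical).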