gradio_app_asy.py CHANGED (+31 -73)
@@ -1,4 +1,4 @@
-
+import spaces
 import gradio as gr
 import torch
 import numpy as np
@@ -25,48 +25,8 @@ logging.basicConfig(level=logging.DEBUG)
 
 accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
 
-
-
-
-# # Model paths dynamically retrieved using selected model
-# model_paths = {
-#     'Wood Sculpture': {
-#         'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
-#         'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
-#         'LORA_REPO': "showlab/makeanything",
-#         'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
-#         "Frame": 4
-#     },
-#     'LEGO': {
-#         'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
-#         'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
-#         'LORA_REPO': "showlab/makeanything",
-#         'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
-#         "Frame": 9
-#     },
-#     'Sketch': {
-#         'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
-#         'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
-#         'LORA_REPO': "showlab/makeanything",
-#         'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
-#         "Frame": 9
-#     },
-#     'Portrait': {
-#         'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
-#         'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
-#         'LORA_REPO': "showlab/makeanything",
-#         'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
-#         "Frame": 9
-#     }
-# }
-
-# # Common paths
-# clip_repo_id = "comfyanonymous/flux_text_encoders"
-# t5xxl_file = "t5xxl_fp16.safetensors"
-# clip_l_file = "clip_l.safetensors"
-# ae_repo_id = "black-forest-labs/FLUX.1-dev"
-# ae_file = "ae.safetensors"
-
+hf_token = os.getenv("HF_TOKEN")
+login(token=hf_token)
 domain_index = {
     'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
     'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
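The two `+` lines above authenticate the process to the Hugging Face Hub (the FLUX.1-dev repo that supplies `ae.safetensors` is gated). They assume `import os` and `from huggingface_hub import login` are present in the file's import block, which this diff does not show. A minimal sketch of the same step, with a guard for a missing token added here only for illustration (the commit logs in unconditionally):

```python
import os
from huggingface_hub import login

# Token comes from a Space secret / environment variable, never hard-coded.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    # Authenticates subsequent hf_hub_download calls in this process.
    login(token=hf_token)
```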
@@ -76,15 +36,19 @@ domain_index = {
 }
 
 lora_paths = {
-    "9 frame": "
-    "4 frame": "
+    "9 frame": "asymmetric_lora/asymmetric_lora_9f_general.safetensors",
+    "4 frame": "asymmetric_lora/asymmetric_lora_4f_general.safetensors"
 }
-BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/unet/flux1-dev-fp8.safetensors"
-# LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors"
-CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
-T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp8_e4m3fn.safetensors"
-AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
 
+# Common paths
+flux_repo_id = "Kijai/flux-fp8"
+flux_file = "flux1-dev-fp8.safetensors"
+lora_repo_id = "showlab/makeanything"
+clip_repo_id = "comfyanonymous/flux_text_encoders"
+t5xxl_file = "t5xxl_fp16.safetensors"
+clip_l_file = "clip_l.safetensors"
+ae_repo_id = "black-forest-labs/FLUX.1-dev"
+ae_file = "ae.safetensors"
 
 # Model placeholders
 model = None
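Every checkpoint is now addressed by Hub coordinates and resolved through the app's `download_file` helper instead of hard-coded NAS paths. Only the helper's signature, `def download_file(repo_id, file_name)`, is visible in the next hunk's header; assuming it is a thin wrapper over `huggingface_hub.hf_hub_download`, its body would look roughly like:

```python
from huggingface_hub import hf_hub_download

def download_file(repo_id, file_name):
    # Fetch one file from the Hub (or reuse the local cache) and
    # return the path to the cached copy.
    return hf_hub_download(repo_id=repo_id, filename=file_name)
```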
@@ -100,22 +64,25 @@ def download_file(repo_id, file_name):
 # Load model function with dynamic paths based on the selected model
 def load_target_model(frame, domain):
     global model, clip_l, t5xxl, ae, lora_model
+    BASE_FLUX_CHECKPOINT = download_file(flux_repo_id, flux_file)
+    CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
+    T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
+    AE_PATH = download_file(ae_repo_id, ae_file)
+    LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
 
     logger.info("Loading models...")
     # try:
     if model is None or clip_l is None or t5xxl is None or ae is None:
+        _, model = flux_utils.load_flow_model(
+            BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
+        )
         clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
         clip_l.eval()
         t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
         t5xxl.eval()
         ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
         logger.info("Models loaded successfully.")
-    # Load models
-    _, model = flux_utils.load_flow_model(
-        BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
-    )
     # Load LoRA weights
-    LORA_WEIGHTS_PATH = lora_paths[frame]
     multiplier = 1.0
     weights_sd = load_file(LORA_WEIGHTS_PATH)
     lora_ups_num = 10 if frame == "9 frame" else 21
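Moving `load_flow_model` inside the `is None` guard is the functional change in the hunk above: the fp8 transformer is now read from disk only while the module-level `model` global is still unset, rather than on every invocation of `load_target_model`. The caching pattern in isolation (all names below are illustrative, not taken from the app):

```python
_model = None  # module-level cache, like the app's `model` global

def get_model(path, loader):
    """Return the cached model, paying the load cost only on first use."""
    global _model
    if _model is None:
        _model = loader(path)  # expensive checkpoint read happens once
    return _model
```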
@@ -131,12 +98,8 @@ def load_target_model(frame, domain):
     logger.info("Models loaded successfully.")
     return "Models loaded successfully. Using Frame: {}, Domain: {}".format(frame, domain)
 
-    # except Exception as e:
-    #     logger.error(f"Error loading models: {e}")
-    #     return f"Error loading models: {e}"
-
 # The function to generate image from a prompt and conditional image
-
+@spaces.GPU(duration=180)
 def infer(prompt, frame, seed=0):
     global model, clip_l, t5xxl, ae, lora_model
     if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
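`@spaces.GPU` is the Hugging Face Spaces ZeroGPU decorator (hence the new `import spaces` at the top of the file): the Space holds no GPU while idle, and a device is attached only for the duration of each decorated call. `duration=180` raises the per-call time budget above the default to cover slow multi-frame sampling. Minimal usage sketch:

```python
import spaces  # available inside Hugging Face Spaces on ZeroGPU hardware

@spaces.GPU(duration=180)  # request a GPU for up to 180 s per call
def infer(prompt, frame, seed=0):
    # all CUDA work must happen inside the decorated function
    ...
```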
@@ -257,19 +220,14 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
-
-            with gr.Column(scale=1):
-                frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Floor")
-            with gr.Column(scale=2):
-                domain_selector = gr.Dropdown(choices=[], label="Select Domains")
-
-    # Load Model Button
-    load_button = gr.Button("Load Model")
-
+            frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Model")
         with gr.Column(scale=1):
-
-
-
+            load_button = gr.Button("Load Model")
+        with gr.Column(scale=2):
+            status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=2)
+        with gr.Column(scale=2):
+            domain_selector = gr.Dropdown(choices=[], label="Select Domains")
+
     with gr.Row():
         with gr.Column(scale=1):
             # Input for the prompt
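The rebuilt row places the four controls in sibling columns instead of nesting columns inside the first one, and adds a read-only `status_box` for load feedback. The event wiring is outside this diff; a typical hookup, with the argument lists assumed rather than shown, would be:

```python
# Hypothetical wiring; the actual .click() registration is not in this diff.
load_button.click(
    fn=load_target_model,                     # returns the status string
    inputs=[frame_selector, domain_selector],
    outputs=status_box,
)
```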
@@ -326,4 +284,4 @@ with gr.Blocks() as demo:
 # )
 
 # Launch the Gradio app
-demo.launch(
+demo.launch()