yiren98 commited on
Commit
aefdc19
·
1 Parent(s): 993b150
Files changed (1) hide show
  1. gradio_app_asy.py +31 -73
gradio_app_asy.py CHANGED
@@ -1,4 +1,4 @@
1
- # import spaces
2
  import gradio as gr
3
  import torch
4
  import numpy as np
@@ -25,48 +25,8 @@ logging.basicConfig(level=logging.DEBUG)
25
 
26
  accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
27
 
28
- # hf_token = os.getenv("HF_TOKEN")
29
- # login(token=hf_token)
30
-
31
- # # Model paths dynamically retrieved using selected model
32
- # model_paths = {
33
- # 'Wood Sculpture': {
34
- # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
35
- # 'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
36
- # 'LORA_REPO': "showlab/makeanything",
37
- # 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
38
- # "Frame": 4
39
- # },
40
- # 'LEGO': {
41
- # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
42
- # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
43
- # 'LORA_REPO': "showlab/makeanything",
44
- # 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
45
- # "Frame": 9
46
- # },
47
- # 'Sketch': {
48
- # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
49
- # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
50
- # 'LORA_REPO': "showlab/makeanything",
51
- # 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
52
- # "Frame": 9
53
- # },
54
- # 'Portrait': {
55
- # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
56
- # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
57
- # 'LORA_REPO': "showlab/makeanything",
58
- # 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
59
- # "Frame": 9
60
- # }
61
- # }
62
-
63
- # # Common paths
64
- # clip_repo_id = "comfyanonymous/flux_text_encoders"
65
- # t5xxl_file = "t5xxl_fp16.safetensors"
66
- # clip_l_file = "clip_l.safetensors"
67
- # ae_repo_id = "black-forest-labs/FLUX.1-dev"
68
- # ae_file = "ae.safetensors"
69
-
70
  domain_index = {
71
  'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
72
  'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
@@ -76,15 +36,19 @@ domain_index = {
76
  }
77
 
78
  lora_paths = {
79
- "9 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors",
80
- "4 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_4f_general.safetensors"
81
  }
82
- BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/unet/flux1-dev-fp8.safetensors"
83
- # LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors"
84
- CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
85
- T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp8_e4m3fn.safetensors"
86
- AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
87
 
 
 
 
 
 
 
 
 
 
88
 
89
  # Model placeholders
90
  model = None
@@ -100,22 +64,25 @@ def download_file(repo_id, file_name):
100
  # Load model function with dynamic paths based on the selected model
101
  def load_target_model(frame, domain):
102
  global model, clip_l, t5xxl, ae, lora_model
 
 
 
 
 
103
 
104
  logger.info("Loading models...")
105
  # try:
106
  if model is None is None or clip_l is None or t5xxl is None or ae is None:
 
 
 
107
  clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
108
  clip_l.eval()
109
  t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
110
  t5xxl.eval()
111
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
112
  logger.info("Models loaded successfully.")
113
- # Load models
114
- _, model = flux_utils.load_flow_model(
115
- BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
116
- )
117
  # Load LoRA weights
118
- LORA_WEIGHTS_PATH = lora_paths[frame]
119
  multiplier = 1.0
120
  weights_sd = load_file(LORA_WEIGHTS_PATH)
121
  lora_ups_num = 10 if frame=="9 frame" else 21
@@ -131,12 +98,8 @@ def load_target_model(frame, domain):
131
  logger.info("Models loaded successfully.")
132
  return "Models loaded successfully. Using Frame: {}, Damain: {}".format(frame, domain)
133
 
134
- # except Exception as e:
135
- # logger.error(f"Error loading models: {e}")
136
- # return f"Error loading models: {e}"
137
-
138
  # The function to generate image from a prompt and conditional image
139
- # @spaces.GPU(duration=180)
140
  def infer(prompt, frame, seed=0):
141
  global model, clip_l, t5xxl, ae, lora_model
142
  if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
@@ -257,19 +220,14 @@ with gr.Blocks() as demo:
257
 
258
  with gr.Row():
259
  with gr.Column(scale=1):
260
- with gr.Row():
261
- with gr.Column(scale=1):
262
- frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Floor")
263
- with gr.Column(scale=2):
264
- domain_selector = gr.Dropdown(choices=[], label="Select Domains")
265
-
266
- # Load Model Button
267
- load_button = gr.Button("Load Model")
268
-
269
  with gr.Column(scale=1):
270
- # Status message box
271
- status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
272
-
 
 
 
273
  with gr.Row():
274
  with gr.Column(scale=1):
275
  # Input for the prompt
@@ -326,4 +284,4 @@ with gr.Blocks() as demo:
326
  # )
327
 
328
  # Launch the Gradio app
329
- demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  import numpy as np
 
25
 
26
  accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
27
 
28
+ hf_token = os.getenv("HF_TOKEN")
29
+ login(token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  domain_index = {
31
  'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
32
  'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
 
36
  }
37
 
38
  lora_paths = {
39
+ "9 frame": "asymmetric_lora/asymmetric_lora_9f_general.safetensors",
40
+ "4 frame": "asymmetric_lora/asymmetric_lora_4f_general.safetensors"
41
  }
 
 
 
 
 
42
 
43
+ # Common paths
44
+ flux_repo_id="Kijai/flux-fp8"
45
+ flux_file="flux1-dev-fp8.safetensors"
46
+ lora_repo_id="showlab/makeanything"
47
+ clip_repo_id = "comfyanonymous/flux_text_encoders"
48
+ t5xxl_file = "t5xxl_fp16.safetensors"
49
+ clip_l_file = "clip_l.safetensors"
50
+ ae_repo_id = "black-forest-labs/FLUX.1-dev"
51
+ ae_file = "ae.safetensors"
52
 
53
  # Model placeholders
54
  model = None
 
64
  # Load model function with dynamic paths based on the selected model
65
  def load_target_model(frame, domain):
66
  global model, clip_l, t5xxl, ae, lora_model
67
+ BASE_FLUX_CHECKPOINT=download_file(flux_repo_id, flux_file)
68
+ CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
69
+ T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
70
+ AE_PATH = download_file(ae_repo_id, ae_file)
71
+ LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
72
 
73
  logger.info("Loading models...")
74
  # try:
75
  if model is None is None or clip_l is None or t5xxl is None or ae is None:
76
+ _, model = flux_utils.load_flow_model(
77
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
78
+ )
79
  clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
80
  clip_l.eval()
81
  t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
82
  t5xxl.eval()
83
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
84
  logger.info("Models loaded successfully.")
 
 
 
 
85
  # Load LoRA weights
 
86
  multiplier = 1.0
87
  weights_sd = load_file(LORA_WEIGHTS_PATH)
88
  lora_ups_num = 10 if frame=="9 frame" else 21
 
98
  logger.info("Models loaded successfully.")
99
  return "Models loaded successfully. Using Frame: {}, Damain: {}".format(frame, domain)
100
 
 
 
 
 
101
  # The function to generate image from a prompt and conditional image
102
+ @spaces.GPU(duration=180)
103
  def infer(prompt, frame, seed=0):
104
  global model, clip_l, t5xxl, ae, lora_model
105
  if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
 
220
 
221
  with gr.Row():
222
  with gr.Column(scale=1):
223
+ frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Model")
 
 
 
 
 
 
 
 
224
  with gr.Column(scale=1):
225
+ load_button = gr.Button("Load Model")
226
+ with gr.Column(scale=2):
227
+ status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=2)
228
+ with gr.Column(scale=2):
229
+ domain_selector = gr.Dropdown(choices=[], label="Select Domains")
230
+
231
  with gr.Row():
232
  with gr.Column(scale=1):
233
  # Input for the prompt
 
284
  # )
285
 
286
  # Launch the Gradio app
287
+ demo.launch()