multimodalart HF Staff commited on
Commit
c9cb9a7
·
verified ·
1 Parent(s): 4c6ab9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -47
app.py CHANGED
@@ -67,12 +67,9 @@ pipe = DiffusionPipeline.from_pretrained(
67
  base_model, scheduler=scheduler, torch_dtype=dtype
68
  ).to(device)
69
 
70
- # Store Lightning LoRA info
71
- lightning_lora = {
72
- "repo": "lightx2v/Qwen-Image-Lightning",
73
- "weight_name": "Qwen-Image-Lightning-8steps-V1.0.safetensors",
74
- "loaded": False
75
- }
76
 
77
  MAX_SEED = np.iinfo(np.int32).max
78
 
@@ -134,29 +131,11 @@ def update_selection(evt: gr.SelectData, aspect_ratio):
134
  )
135
 
136
  def handle_speed_mode(speed_mode):
137
- """Handle the speed/quality toggle for Lightning LoRA."""
138
- global lightning_lora
139
-
140
  if speed_mode == "Speed (8 steps)":
141
- # Load Lightning LoRA if not already loaded
142
- if not lightning_lora["loaded"]:
143
- with calculateDuration("Loading Lightning LoRA"):
144
- pipe.load_lora_weights(
145
- lightning_lora["repo"],
146
- weight_name=lightning_lora["weight_name"],
147
- adapter_name="lightning"
148
- )
149
- lightning_lora["loaded"] = True
150
- return gr.update(value="Lightning LoRA loaded for fast generation"), 8, 1.0
151
- return gr.update(value="Lightning LoRA already loaded"), 8, 1.0
152
  else: # Quality mode
153
- # Unload Lightning LoRA if loaded
154
- if lightning_lora["loaded"]:
155
- with calculateDuration("Unloading Lightning LoRA"):
156
- pipe.unload_lora_weights()
157
- lightning_lora["loaded"] = False
158
- return gr.update(value="Lightning LoRA unloaded for quality generation"), 28, 3.5
159
- return gr.update(value="Quality mode active"), 28, 3.5
160
 
161
  @spaces.GPU(duration=70)
162
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
@@ -198,36 +177,35 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
198
  else:
199
  prompt_mash = prompt
200
 
201
- # First, handle Lightning LoRA based on speed mode
202
- global lightning_lora
203
- if speed_mode == "Speed (8 steps)" and not lightning_lora["loaded"]:
204
- with calculateDuration("Loading Lightning LoRA"):
 
 
 
 
205
  pipe.load_lora_weights(
206
- lightning_lora["repo"],
207
- weight_name=lightning_lora["weight_name"],
208
  adapter_name="lightning"
209
  )
210
- lightning_lora["loaded"] = True
211
- elif speed_mode == "Quality (28 steps)" and lightning_lora["loaded"]:
212
- with calculateDuration("Unloading Lightning LoRA"):
213
- pipe.unload_lora_weights()
214
- lightning_lora["loaded"] = False
215
-
216
- # Load the selected style LoRA
217
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
218
- weight_name = selected_lora.get("weights", None)
219
-
220
- # If Lightning is loaded, we need to handle multiple LoRAs
221
- if lightning_lora["loaded"]:
222
  pipe.load_lora_weights(
223
  lora_path,
224
  weight_name=weight_name,
225
  low_cpu_mem_usage=True,
226
  adapter_name="style"
227
  )
228
- # Set both adapters active
 
229
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
230
- else:
 
 
 
231
  pipe.load_lora_weights(
232
  lora_path,
233
  weight_name=weight_name,
 
67
  base_model, scheduler=scheduler, torch_dtype=dtype
68
  ).to(device)
69
 
70
+ # Lightning LoRA info (no global state)
71
+ LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
72
+ LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
 
 
 
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
 
 
131
  )
132
 
133
  def handle_speed_mode(speed_mode):
134
+ """Update UI based on speed/quality toggle."""
 
 
135
  if speed_mode == "Speed (8 steps)":
136
+ return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
 
 
 
 
 
 
 
 
 
 
137
  else: # Quality mode
138
+ return gr.update(value="Quality mode selected - 28 steps for best quality"), 28, 3.5
 
 
 
 
 
 
139
 
140
  @spaces.GPU(duration=70)
141
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
 
177
  else:
178
  prompt_mash = prompt
179
 
180
+ # Always unload any existing LoRAs first to avoid conflicts
181
+ with calculateDuration("Unloading existing LoRAs"):
182
+ pipe.unload_lora_weights()
183
+
184
+ # Load LoRAs based on speed mode
185
+ if speed_mode == "Speed (8 steps)":
186
+ with calculateDuration("Loading Lightning LoRA and style LoRA"):
187
+ # Load Lightning LoRA first
188
  pipe.load_lora_weights(
189
+ LIGHTNING_LORA_REPO,
190
+ weight_name=LIGHTNING_LORA_WEIGHT,
191
  adapter_name="lightning"
192
  )
193
+
194
+ # Load the selected style LoRA
195
+ weight_name = selected_lora.get("weights", None)
 
 
 
 
 
 
 
 
 
196
  pipe.load_lora_weights(
197
  lora_path,
198
  weight_name=weight_name,
199
  low_cpu_mem_usage=True,
200
  adapter_name="style"
201
  )
202
+
203
+ # Set both adapters active with their weights
204
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
205
+ else:
206
+ # Quality mode - only load the style LoRA
207
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
208
+ weight_name = selected_lora.get("weights", None)
209
  pipe.load_lora_weights(
210
  lora_path,
211
  weight_name=weight_name,