Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
|
|
7 |
import numpy as np
|
8 |
import gradio as gr
|
9 |
import spaces
|
|
|
10 |
|
11 |
# --- INITIAL SETUP ---
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -28,7 +29,6 @@ print("Moving pipeline to device (ZeroGPU will handle offloading)...")
|
|
28 |
pipe.to(device)
|
29 |
|
30 |
# --- LORA SETUP ---
|
31 |
-
# We will NOT fuse anything. Everything will be handled dynamically.
|
32 |
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
|
33 |
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
|
34 |
BASE_LORA_NAME = "causvid_lora"
|
@@ -44,66 +44,48 @@ except Exception as e:
|
|
44 |
|
45 |
print("Initialization complete. Gradio is starting...")
|
46 |
|
47 |
-
def move_adapter_to_device(pipe, adapter_name, device):
|
48 |
-
"""
|
49 |
-
Surgically moves only the parameters of a specific LoRA adapter to the target device.
|
50 |
-
This avoids touching the base model's meta tensors.
|
51 |
-
"""
|
52 |
-
print(f"Moving adapter '{adapter_name}' to {device}...")
|
53 |
-
for param in pipe.transformer.parameters():
|
54 |
-
if hasattr(param, "adapter_name") and param.adapter_name == adapter_name:
|
55 |
-
param.data = param.data.to(device, non_blocking=True)
|
56 |
-
if param.grad is not None:
|
57 |
-
param.grad.data = param.grad.data.to(device, non_blocking=True)
|
58 |
-
print(f"✅ Adapter '{adapter_name}' moved.")
|
59 |
|
60 |
@spaces.GPU()
|
61 |
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
|
62 |
|
63 |
# --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
|
64 |
-
# Start with a clean slate by disabling any active adapters from previous runs
|
65 |
-
pipe.disable_lora()
|
66 |
-
|
67 |
active_adapters = []
|
68 |
adapter_weights = []
|
69 |
|
70 |
-
# 1. Load the Base LoRA
|
71 |
if causvid_path:
|
72 |
try:
|
73 |
-
# We load it for every run to ensure a clean state
|
74 |
print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
|
75 |
-
|
76 |
-
|
77 |
-
# THE CRITICAL FIX: Move only this adapter's weights to the GPU
|
78 |
-
move_adapter_to_device(pipe, BASE_LORA_NAME, device)
|
79 |
-
|
80 |
active_adapters.append(BASE_LORA_NAME)
|
81 |
adapter_weights.append(1.0)
|
|
|
82 |
except Exception as e:
|
83 |
print(f"⚠️ Failed to load base LoRA: {e}")
|
84 |
|
85 |
-
# 2. Load the Custom LoRA if provided
|
86 |
clean_lora_id = lora_id.strip() if lora_id else ""
|
87 |
if clean_lora_id:
|
88 |
try:
|
89 |
print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
|
90 |
-
|
91 |
-
|
92 |
-
# THE CRITICAL FIX: Move only this adapter's weights to the GPU
|
93 |
-
move_adapter_to_device(pipe, CUSTOM_LORA_NAME, device)
|
94 |
-
|
95 |
active_adapters.append(CUSTOM_LORA_NAME)
|
96 |
adapter_weights.append(1.0)
|
|
|
97 |
except Exception as e:
|
98 |
print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
|
99 |
-
|
100 |
-
if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
|
101 |
del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
|
102 |
|
103 |
# 3. Activate the successfully loaded adapters
|
104 |
if active_adapters:
|
105 |
print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
|
106 |
pipe.set_adapters(active_adapters, adapter_weights)
|
|
|
|
|
|
|
107 |
|
108 |
apply_cache_on_pipe(pipe)
|
109 |
|
@@ -122,16 +104,13 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
|
|
122 |
return Image.fromarray(image)
|
123 |
finally:
|
124 |
# --- PROPER CLEANUP ---
|
125 |
-
|
126 |
-
#
|
127 |
-
|
128 |
-
|
129 |
-
#
|
130 |
-
|
131 |
-
|
132 |
-
if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
|
133 |
-
del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
|
134 |
-
print("✅ LoRAs cleaned up.")
|
135 |
|
136 |
|
137 |
iface = gr.Interface(
|
|
|
7 |
import numpy as np
|
8 |
import gradio as gr
|
9 |
import spaces
|
10 |
+
import gc
|
11 |
|
12 |
# --- INITIAL SETUP ---
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
29 |
pipe.to(device)
|
30 |
|
31 |
# --- LORA SETUP ---
|
|
|
32 |
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
|
33 |
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
|
34 |
BASE_LORA_NAME = "causvid_lora"
|
|
|
44 |
|
45 |
print("Initialization complete. Gradio is starting...")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
@spaces.GPU()
|
49 |
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
|
50 |
|
51 |
# --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
|
|
|
|
|
|
|
52 |
active_adapters = []
|
53 |
adapter_weights = []
|
54 |
|
55 |
+
# 1. Load the Base LoRA directly onto the correct device
|
56 |
if causvid_path:
|
57 |
try:
|
|
|
58 |
print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
|
59 |
+
# THE CORRECT FIX: Use device_map to load the LoRA directly to the GPU.
|
60 |
+
pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME, device_map={"":device})
|
|
|
|
|
|
|
61 |
active_adapters.append(BASE_LORA_NAME)
|
62 |
adapter_weights.append(1.0)
|
63 |
+
print("✅ Base LoRA loaded to device.")
|
64 |
except Exception as e:
|
65 |
print(f"⚠️ Failed to load base LoRA: {e}")
|
66 |
|
67 |
+
# 2. Load the Custom LoRA if provided, also directly to the device
|
68 |
clean_lora_id = lora_id.strip() if lora_id else ""
|
69 |
if clean_lora_id:
|
70 |
try:
|
71 |
print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
|
72 |
+
# THE CORRECT FIX: Also use device_map for the custom LoRA.
|
73 |
+
pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME, device_map={"":device})
|
|
|
|
|
|
|
74 |
active_adapters.append(CUSTOM_LORA_NAME)
|
75 |
adapter_weights.append(1.0)
|
76 |
+
print("✅ Custom LoRA loaded to device.")
|
77 |
except Exception as e:
|
78 |
print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
|
79 |
+
if CUSTOM_LORA_NAME in getattr(pipe.transformer, 'peft_config', {}):
|
|
|
80 |
del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
|
81 |
|
82 |
# 3. Activate the successfully loaded adapters
|
83 |
if active_adapters:
|
84 |
print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
|
85 |
pipe.set_adapters(active_adapters, adapter_weights)
|
86 |
+
else:
|
87 |
+
# Ensure LoRA is disabled if no adapters were loaded
|
88 |
+
pipe.disable_lora()
|
89 |
|
90 |
apply_cache_on_pipe(pipe)
|
91 |
|
|
|
104 |
return Image.fromarray(image)
|
105 |
finally:
|
106 |
# --- PROPER CLEANUP ---
|
107 |
+
# The most reliable way to clean up in this complex environment is to unload ALL LoRAs.
|
108 |
+
# This avoids leaving dangling configs.
|
109 |
+
print("Unloading all LoRAs to ensure a clean state...")
|
110 |
+
pipe.unload_lora_weights()
|
111 |
+
gc.collect() # Force garbage collection
|
112 |
+
torch.cuda.empty_cache() # Clear CUDA cache
|
113 |
+
print("✅ LoRAs unloaded and memory cleaned.")
|
|
|
|
|
|
|
114 |
|
115 |
|
116 |
iface = gr.Interface(
|