Update app.py
app.py CHANGED
@@ -27,8 +27,6 @@ OVIMX2_NAME = "ovimx2_optional"
 OVIMX3_NAME = "ovimx3_optional"
 PRTHE_NAME = "prthe_optional"

-SUCCESSFULLY_LOADED_LORAS = {}
-…
 lora_definitions = {
     CAUSVID_NAME: ("joerose/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32", "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"),
     PERSONVID_NAME: ("ovi054/p3r5onVid1000", None),
@@ -39,6 +37,9 @@ lora_definitions = {
     PRTHE_NAME: ("ovi054/prwthxVid", None)
 }

+# --- THIS ORDERED LIST IS NOW CRITICAL ---
+# It defines the consistent order for the weight vector.
+ALL_ADAPTER_NAMES = []
 for name, (repo, filename) in lora_definitions.items():
     print(f"Attempting to load LoRA '{name}'...")
     try:
@@ -48,7 +49,7 @@ for name, (repo, filename) in lora_definitions.items():
         else:
             pipe.load_lora_weights(repo, adapter_name=name, device_map="auto")
         print(f"✅ LoRA '{name}' loaded successfully.")
-…
+        ALL_ADAPTER_NAMES.append(name)
     except Exception as e:
         print(f"⚠️ LoRA '{name}' could not be loaded: {e}")

@@ -60,38 +61,59 @@ OPTIONAL_LORA_MAP = {
     "ovi054/ovimxVid2500": OVIMX3_NAME,
     "ovi054/prwthxVid": PRTHE_NAME,
 }
-…
+# Filter choices to only include LoRAs that actually loaded
+OPTIONAL_LORA_CHOICES = {k: v for k, v in OPTIONAL_LORA_MAP.items() if v in ALL_ADAPTER_NAMES}
+
+
+# --- SET INITIAL STATE AT STARTUP ---
+# Set ALL adapters as active, but control them with weights.
+if ALL_ADAPTER_NAMES:
+    print(f"Setting up all {len(ALL_ADAPTER_NAMES)} loaded adapters in the pipeline.")
+
+    # Start with all weights at 0.0
+    initial_weights = [0.0] * len(ALL_ADAPTER_NAMES)
+
+    # Set the base LoRA's weight to 1.0
+    try:
+        base_lora_index = ALL_ADAPTER_NAMES.index(CAUSVID_NAME)
+        initial_weights[base_lora_index] = 1.0
+    except ValueError:
+        print(f"Warning: Base LoRA '{CAUSVID_NAME}' not found in the loaded list. All weights start at 0.")

-…
-…
-…
+    print(f"Setting initial state: adapters={ALL_ADAPTER_NAMES}, weights={initial_weights}")
+    pipe.set_adapters(ALL_ADAPTER_NAMES, adapter_weights=initial_weights)
+else:
+    print("No LoRAs were loaded.")

 print("Initialization complete. Gradio is starting...")

 @spaces.GPU()
 def generate(prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id, progress=gr.Progress(track_tqdm=True)):

-    # --- Step 1: ALWAYS build the …
-…
-    active_adapters = []
-    adapter_weights = []
+    # --- Step 1: ALWAYS build the full weight vector from scratch for THIS run ---

-    # …
-…
-…
-…
-…
-…
+    # Start with the default state: base LoRA on, others off.
+    adapter_weights = [0.0] * len(ALL_ADAPTER_NAMES)
+    try:
+        base_lora_index = ALL_ADAPTER_NAMES.index(CAUSVID_NAME)
+        adapter_weights[base_lora_index] = 1.0
+    except ValueError:
+        pass # Base lora was not loaded, so its weight remains 0.
+
+    # If an optional LoRA is selected, turn its weight on.
     if optional_lora_id and optional_lora_id != "None":
         internal_name_to_add = OPTIONAL_LORA_CHOICES.get(optional_lora_id)
         if internal_name_to_add:
-…
-…
+            try:
+                optional_lora_index = ALL_ADAPTER_NAMES.index(internal_name_to_add)
+                adapter_weights[optional_lora_index] = 1.0
+            except ValueError:
+                print(f"Warning: Could not find index for selected LoRA '{internal_name_to_add}'. It will not be applied.")

     # --- Step 2: Apply the calculated state, OVERWRITING any previous state ---
-    # …
-    print(f"Setting …
-    pipe.set_adapters( …
+    # We always pass the FULL list of adapters, just with different weights.
+    print(f"Setting weights for this run: {list(zip(ALL_ADAPTER_NAMES, adapter_weights))}")
+    pipe.set_adapters(ALL_ADAPTER_NAMES, adapter_weights=adapter_weights)

     apply_cache_on_pipe(pipe)

@@ -109,8 +131,6 @@ def generate(prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id, progress=gr.Progress(track_tqdm=True)):
     image = (image * 255).astype(np.uint8)
     return Image.fromarray(image)

-    # --- No cleanup step is needed, as the next run will set its own state ---
-…

 # --- Gradio Interface ---
 iface = gr.Interface(
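For reference, a minimal standalone sketch of the weight-vector pattern this commit moves to: every successfully loaded adapter stays attached to the pipeline, and each request only rebuilds the list of weights passed to pipe.set_adapters(adapter_names, adapter_weights=...). The helper build_weights and the adapter names below are illustrative placeholders, not code from app.py.

def build_weights(all_adapters, base_adapter, selected=None):
    """One weight per adapter, in the same order as all_adapters:
    the base LoRA and the selected optional LoRA get 1.0, everything else 0.0."""
    weights = [0.0] * len(all_adapters)
    for name in (base_adapter, selected):
        if name in all_adapters:
            weights[all_adapters.index(name)] = 1.0
    return weights

# Example: three adapters loaded, user picked the "prthe_optional" LoRA.
adapters = ["causvid", "personvid_optional", "prthe_optional"]
print(build_weights(adapters, "causvid", selected="prthe_optional"))
# -> [1.0, 0.0, 1.0]

# Per request, the full adapter list is passed with fresh weights, so each call
# overwrites whatever the previous request set:
# pipe.set_adapters(adapters, adapter_weights=build_weights(adapters, "causvid", selected="prthe_optional"))

Because the weights are rebuilt from scratch on every call, the cleanup comment removed in the last hunk still holds: no teardown is needed, the next run simply sets its own state.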