ovi054 commited on
Commit
a474898
·
verified ·
1 Parent(s): be31f5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -24
app.py CHANGED
@@ -27,8 +27,6 @@ OVIMX2_NAME = "ovimx2_optional"
27
  OVIMX3_NAME = "ovimx3_optional"
28
  PRTHE_NAME = "prthe_optional"
29
 
30
- SUCCESSFULLY_LOADED_LORAS = {}
31
-
32
  lora_definitions = {
33
  CAUSVID_NAME: ("joerose/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32", "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"),
34
  PERSONVID_NAME: ("ovi054/p3r5onVid1000", None),
@@ -39,6 +37,9 @@ lora_definitions = {
39
  PRTHE_NAME: ("ovi054/prwthxVid", None)
40
  }
41
 
 
 
 
42
  for name, (repo, filename) in lora_definitions.items():
43
  print(f"Attempting to load LoRA '{name}'...")
44
  try:
@@ -48,7 +49,7 @@ for name, (repo, filename) in lora_definitions.items():
48
  else:
49
  pipe.load_lora_weights(repo, adapter_name=name, device_map="auto")
50
  print(f"✅ LoRA '{name}' loaded successfully.")
51
- SUCCESSFULLY_LOADED_LORAS[name] = repo
52
  except Exception as e:
53
  print(f"⚠️ LoRA '{name}' could not be loaded: {e}")
54
 
@@ -60,38 +61,59 @@ OPTIONAL_LORA_MAP = {
60
  "ovi054/ovimxVid2500": OVIMX3_NAME,
61
  "ovi054/prwthxVid": PRTHE_NAME,
62
  }
63
- OPTIONAL_LORA_CHOICES = {k: v for k, v in OPTIONAL_LORA_MAP.items() if v in SUCCESSFULLY_LOADED_LORAS}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # At startup, we set a known initial state. Setting no adapters is the cleanest start.
66
- print("Setting a clean initial state (no adapters active).")
67
- pipe.set_adapters([], adapter_weights=[])
 
68
 
69
  print("Initialization complete. Gradio is starting...")
70
 
71
  @spaces.GPU()
72
  def generate(prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id, progress=gr.Progress(track_tqdm=True)):
73
 
74
- # --- Step 1: ALWAYS build the desired state from scratch for THIS run ---
75
-
76
- active_adapters = []
77
- adapter_weights = []
78
 
79
- # Always include the base LoRA if it was loaded successfully
80
- if CAUSVID_NAME in SUCCESSFULLY_LOADED_LORAS:
81
- active_adapters.append(CAUSVID_NAME)
82
- adapter_weights.append(1.0)
83
-
84
- # If an optional LoRA is selected, add it to the list
 
 
 
85
  if optional_lora_id and optional_lora_id != "None":
86
  internal_name_to_add = OPTIONAL_LORA_CHOICES.get(optional_lora_id)
87
  if internal_name_to_add:
88
- active_adapters.append(internal_name_to_add)
89
- adapter_weights.append(1.0)
 
 
 
90
 
91
  # --- Step 2: Apply the calculated state, OVERWRITING any previous state ---
92
- # This single call is the source of truth for the run.
93
- print(f"Setting adapters for this run: {active_adapters} with weights: {adapter_weights}")
94
- pipe.set_adapters(active_adapters, adapter_weights=adapter_weights)
95
 
96
  apply_cache_on_pipe(pipe)
97
 
@@ -109,8 +131,6 @@ def generate(prompt, negative_prompt, width, height, num_inference_steps, option
109
  image = (image * 255).astype(np.uint8)
110
  return Image.fromarray(image)
111
 
112
- # --- No cleanup step is needed, as the next run will set its own state ---
113
-
114
 
115
  # --- Gradio Interface ---
116
  iface = gr.Interface(
 
27
  OVIMX3_NAME = "ovimx3_optional"
28
  PRTHE_NAME = "prthe_optional"
29
 
 
 
30
  lora_definitions = {
31
  CAUSVID_NAME: ("joerose/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32", "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"),
32
  PERSONVID_NAME: ("ovi054/p3r5onVid1000", None),
 
37
  PRTHE_NAME: ("ovi054/prwthxVid", None)
38
  }
39
 
40
+ # --- THIS ORDERED LIST IS NOW CRITICAL ---
41
+ # It defines the consistent order for the weight vector.
42
+ ALL_ADAPTER_NAMES = []
43
  for name, (repo, filename) in lora_definitions.items():
44
  print(f"Attempting to load LoRA '{name}'...")
45
  try:
 
49
  else:
50
  pipe.load_lora_weights(repo, adapter_name=name, device_map="auto")
51
  print(f"✅ LoRA '{name}' loaded successfully.")
52
+ ALL_ADAPTER_NAMES.append(name)
53
  except Exception as e:
54
  print(f"⚠️ LoRA '{name}' could not be loaded: {e}")
55
 
 
61
  "ovi054/ovimxVid2500": OVIMX3_NAME,
62
  "ovi054/prwthxVid": PRTHE_NAME,
63
  }
64
+ # Filter choices to only include LoRAs that actually loaded
65
+ OPTIONAL_LORA_CHOICES = {k: v for k, v in OPTIONAL_LORA_MAP.items() if v in ALL_ADAPTER_NAMES}
66
+
67
+
68
+ # --- SET INITIAL STATE AT STARTUP ---
69
+ # Set ALL adapters as active, but control them with weights.
70
+ if ALL_ADAPTER_NAMES:
71
+ print(f"Setting up all {len(ALL_ADAPTER_NAMES)} loaded adapters in the pipeline.")
72
+
73
+ # Start with all weights at 0.0
74
+ initial_weights = [0.0] * len(ALL_ADAPTER_NAMES)
75
+
76
+ # Set the base LoRA's weight to 1.0
77
+ try:
78
+ base_lora_index = ALL_ADAPTER_NAMES.index(CAUSVID_NAME)
79
+ initial_weights[base_lora_index] = 1.0
80
+ except ValueError:
81
+ print(f"Warning: Base LoRA '{CAUSVID_NAME}' not found in the loaded list. All weights start at 0.")
82
 
83
+ print(f"Setting initial state: adapters={ALL_ADAPTER_NAMES}, weights={initial_weights}")
84
+ pipe.set_adapters(ALL_ADAPTER_NAMES, adapter_weights=initial_weights)
85
+ else:
86
+ print("No LoRAs were loaded.")
87
 
88
  print("Initialization complete. Gradio is starting...")
89
 
90
  @spaces.GPU()
91
  def generate(prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id, progress=gr.Progress(track_tqdm=True)):
92
 
93
+ # --- Step 1: ALWAYS build the full weight vector from scratch for THIS run ---
 
 
 
94
 
95
+ # Start with the default state: base LoRA on, others off.
96
+ adapter_weights = [0.0] * len(ALL_ADAPTER_NAMES)
97
+ try:
98
+ base_lora_index = ALL_ADAPTER_NAMES.index(CAUSVID_NAME)
99
+ adapter_weights[base_lora_index] = 1.0
100
+ except ValueError:
101
+ pass # Base lora was not loaded, so its weight remains 0.
102
+
103
+ # If an optional LoRA is selected, turn its weight on.
104
  if optional_lora_id and optional_lora_id != "None":
105
  internal_name_to_add = OPTIONAL_LORA_CHOICES.get(optional_lora_id)
106
  if internal_name_to_add:
107
+ try:
108
+ optional_lora_index = ALL_ADAPTER_NAMES.index(internal_name_to_add)
109
+ adapter_weights[optional_lora_index] = 1.0
110
+ except ValueError:
111
+ print(f"Warning: Could not find index for selected LoRA '{internal_name_to_add}'. It will not be applied.")
112
 
113
  # --- Step 2: Apply the calculated state, OVERWRITING any previous state ---
114
+ # We always pass the FULL list of adapters, just with different weights.
115
+ print(f"Setting weights for this run: {list(zip(ALL_ADAPTER_NAMES, adapter_weights))}")
116
+ pipe.set_adapters(ALL_ADAPTER_NAMES, adapter_weights=adapter_weights)
117
 
118
  apply_cache_on_pipe(pipe)
119
 
 
131
  image = (image * 255).astype(np.uint8)
132
  return Image.fromarray(image)
133
 
 
 
134
 
135
  # --- Gradio Interface ---
136
  iface = gr.Interface(