ovi054 committed on
Commit
3e81ff5
·
verified ·
1 Parent(s): 17aa94d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -41
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  import numpy as np
8
  import gradio as gr
9
  import spaces
 
10
 
11
  # --- INITIAL SETUP ---
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -28,7 +29,6 @@ print("Moving pipeline to device (ZeroGPU will handle offloading)...")
28
  pipe.to(device)
29
 
30
  # --- LORA SETUP ---
31
- # We will NOT fuse anything. Everything will be handled dynamically.
32
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
33
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
34
  BASE_LORA_NAME = "causvid_lora"
@@ -44,66 +44,48 @@ except Exception as e:
44
 
45
  print("Initialization complete. Gradio is starting...")
46
 
47
- def move_adapter_to_device(pipe, adapter_name, device):
48
- """
49
- Surgically moves only the parameters of a specific LoRA adapter to the target device.
50
- This avoids touching the base model's meta tensors.
51
- """
52
- print(f"Moving adapter '{adapter_name}' to {device}...")
53
- for param in pipe.transformer.parameters():
54
- if hasattr(param, "adapter_name") and param.adapter_name == adapter_name:
55
- param.data = param.data.to(device, non_blocking=True)
56
- if param.grad is not None:
57
- param.grad.data = param.grad.data.to(device, non_blocking=True)
58
- print(f"✅ Adapter '{adapter_name}' moved.")
59
 
60
  @spaces.GPU()
61
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
62
 
63
  # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
64
- # Start with a clean slate by disabling any active adapters from previous runs
65
- pipe.disable_lora()
66
-
67
  active_adapters = []
68
  adapter_weights = []
69
 
70
- # 1. Load the Base LoRA
71
  if causvid_path:
72
  try:
73
- # We load it for every run to ensure a clean state
74
  print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
75
- pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME)
76
-
77
- # THE CRITICAL FIX: Move only this adapter's weights to the GPU
78
- move_adapter_to_device(pipe, BASE_LORA_NAME, device)
79
-
80
  active_adapters.append(BASE_LORA_NAME)
81
  adapter_weights.append(1.0)
 
82
  except Exception as e:
83
  print(f"⚠️ Failed to load base LoRA: {e}")
84
 
85
- # 2. Load the Custom LoRA if provided
86
  clean_lora_id = lora_id.strip() if lora_id else ""
87
  if clean_lora_id:
88
  try:
89
  print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
90
- pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
91
-
92
- # THE CRITICAL FIX: Move only this adapter's weights to the GPU
93
- move_adapter_to_device(pipe, CUSTOM_LORA_NAME, device)
94
-
95
  active_adapters.append(CUSTOM_LORA_NAME)
96
  adapter_weights.append(1.0)
 
97
  except Exception as e:
98
  print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
99
- # If it fails, delete the adapter config to prevent issues
100
- if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
101
  del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
102
 
103
  # 3. Activate the successfully loaded adapters
104
  if active_adapters:
105
  print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
106
  pipe.set_adapters(active_adapters, adapter_weights)
 
 
 
107
 
108
  apply_cache_on_pipe(pipe)
109
 
@@ -122,16 +104,13 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
122
  return Image.fromarray(image)
123
  finally:
124
  # --- PROPER CLEANUP ---
125
- print("Cleaning up LoRAs for this run...")
126
- # Disable adapters to stop them from being used
127
- pipe.disable_lora()
128
-
129
- # Delete the LoRA configs from the model to truly unload them
130
- if BASE_LORA_NAME in pipe.transformer.peft_config:
131
- del pipe.transformer.peft_config[BASE_LORA_NAME]
132
- if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
133
- del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
134
- print("✅ LoRAs cleaned up.")
135
 
136
 
137
  iface = gr.Interface(
 
7
  import numpy as np
8
  import gradio as gr
9
  import spaces
10
+ import gc
11
 
12
  # --- INITIAL SETUP ---
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
29
  pipe.to(device)
30
 
31
  # --- LORA SETUP ---
 
32
  CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
33
  CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
34
  BASE_LORA_NAME = "causvid_lora"
 
44
 
45
  print("Initialization complete. Gradio is starting...")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  @spaces.GPU()
49
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
50
 
51
  # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
 
 
 
52
  active_adapters = []
53
  adapter_weights = []
54
 
55
+ # 1. Load the Base LoRA directly onto the correct device
56
  if causvid_path:
57
  try:
 
58
  print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
59
+ # THE CORRECT FIX: Use device_map to load the LoRA directly to the GPU.
60
+ pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME, device_map={"":device})
 
 
 
61
  active_adapters.append(BASE_LORA_NAME)
62
  adapter_weights.append(1.0)
63
+ print("✅ Base LoRA loaded to device.")
64
  except Exception as e:
65
  print(f"⚠️ Failed to load base LoRA: {e}")
66
 
67
+ # 2. Load the Custom LoRA if provided, also directly to the device
68
  clean_lora_id = lora_id.strip() if lora_id else ""
69
  if clean_lora_id:
70
  try:
71
  print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
72
+ # THE CORRECT FIX: Also use device_map for the custom LoRA.
73
+ pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME, device_map={"":device})
 
 
 
74
  active_adapters.append(CUSTOM_LORA_NAME)
75
  adapter_weights.append(1.0)
76
+ print("✅ Custom LoRA loaded to device.")
77
  except Exception as e:
78
  print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
79
+ if CUSTOM_LORA_NAME in getattr(pipe.transformer, 'peft_config', {}):
 
80
  del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
81
 
82
  # 3. Activate the successfully loaded adapters
83
  if active_adapters:
84
  print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
85
  pipe.set_adapters(active_adapters, adapter_weights)
86
+ else:
87
+ # Ensure LoRA is disabled if no adapters were loaded
88
+ pipe.disable_lora()
89
 
90
  apply_cache_on_pipe(pipe)
91
 
 
104
  return Image.fromarray(image)
105
  finally:
106
  # --- PROPER CLEANUP ---
107
+ # The most reliable way to clean up in this complex environment is to unload ALL LoRAs.
108
+ # This avoids leaving dangling configs.
109
+ print("Unloading all LoRAs to ensure a clean state...")
110
+ pipe.unload_lora_weights()
111
+ gc.collect() # Force garbage collection
112
+ torch.cuda.empty_cache() # Clear CUDA cache
113
+ print("✅ LoRAs unloaded and memory cleaned.")
 
 
 
114
 
115
 
116
  iface = gr.Interface(