ovi054 committed on
Commit
c8f7bf3
·
verified ·
1 Parent(s): 972e264

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -64
app.py CHANGED
@@ -16,70 +16,93 @@ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=to
16
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
17
  flow_shift = 1.0
18
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
19
-
20
  pipe.to(device)
21
 
22
  # --- LORA SETUP ---
23
- # Define unique names for our adapters
24
- LORA_1_NAME = "causvid_lora"
25
- LORA_2_NAME = "person_lora"
26
-
27
- # 1. Load the first base LoRA ONCE at startup
28
- print("Loading first LoRA (CausVid)...")
29
- # LORA_1_REPO = "Kijai/WanVideo_comfy"
30
- # LORA_1_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors"
31
-
32
- # LORA_1_REPO = "vrgamedevgirl84/Wan14BT2VFusioniX"
33
- # LORA_1_FILENAME = "FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
34
-
35
- LORA_1_REPO = "joerose/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32"
36
- LORA_1_FILENAME = "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
37
-
38
- # LORA_1_REPO = "lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v"
39
- # LORA_1_FILENAME = "loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors"
40
-
41
- try:
42
- lora_1_path = hf_hub_download(repo_id=LORA_1_REPO, filename=LORA_1_FILENAME)
43
- # The `device_map="auto"` can sometimes help in tricky environments
44
- pipe.load_lora_weights(lora_1_path, adapter_name=LORA_1_NAME, device_map="auto")
45
- print(f"✅ Default LoRA '{LORA_1_NAME}' loaded successfully.")
46
- except Exception as e:
47
- print(f"⚠️ Default LoRA '{LORA_1_NAME}' could not be loaded: {e}")
48
-
49
- # 2. Load the second hard-coded LoRA ONCE at startup
50
- print("Loading second LoRA (Person)...")
51
- # LORA_2_REPO = "ovi054/p3r5onVid1900"
52
- # LORA_2_REPO = "ovi054/rosmxVid1500"
53
- LORA_2_REPO = "ovi054/ovimxVid1750"
54
- # Assuming the file is named "pytorch_lora_weights.safetensors" which is standard.
55
- # If it has a different name, you must specify it with the `filename` argument.
56
- try:
57
- # We load the whole repository and diffusers will find the correct file
58
- pipe.load_lora_weights(LORA_2_REPO, adapter_name=LORA_2_NAME, device_map="auto")
59
- print(f" Second LoRA '{LORA_2_NAME}' loaded successfully.")
60
- except Exception as e:
61
- print(f"⚠️ Second LoRA '{LORA_2_NAME}' could not be loaded: {e}")
62
-
63
- pipe.set_adapters([LORA_1_NAME, LORA_2_NAME], adapter_weights=[1.0, 1.0])
64
-
 
 
 
 
65
 
66
  print("Initialization complete. Gradio is starting...")
67
 
68
  @spaces.GPU()
69
- def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
70
-
71
- # --- Activate both hard-coded LoRAs for this run ---
72
- # We set the adapters at the start of every generation to ensure the state is correct.
73
- print("Activating both LoRAs for inference...")
74
- # You can adjust the weights here to change the intensity of each LoRA.
75
- # For example, [1.0, 0.8] would make the second LoRA less strong.
76
- # pipe.set_adapters([LORA_1_NAME, LORA_2_NAME], adapter_weights=[1.0, 1.0])
77
-
78
- apply_cache_on_pipe(
79
- pipe,
80
- )
81
 
 
82
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  output = pipe(
84
  prompt=prompt,
85
  negative_prompt=negative_prompt,
@@ -92,25 +115,23 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
92
  image = output.frames[0][0]
93
  image = (image * 255).astype(np.uint8)
94
  return Image.fromarray(image)
95
- finally:
96
- # It's good practice to disable the adapters after the run,
97
- # although set_adapters() at the start also handles this.
98
- # print("Disabling LoRAs after run.")
99
- # pipe.disable_lora()
100
- pass
101
 
 
 
 
102
 
 
103
  iface = gr.Interface(
104
  fn=generate,
105
  inputs=[
106
  gr.Textbox(label="Input prompt"),
107
- ],
108
- additional_inputs = [
109
  gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
110
  gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
111
  gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
112
  gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
113
- gr.Textbox(label="LoRA ID", visible=False), # Hiding the dynamic LoRA input for now
 
 
114
  ],
115
  outputs=gr.Image(label="output"),
116
  )
 
16
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
17
  flow_shift = 1.0
18
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
 
19
  pipe.to(device)
20
 
21
  # --- LORA SETUP ---
22
+ CAUSVID_NAME = "causvid_base"
23
+ PERSONVID_NAME = "personvid_optional"
24
+ ROSMX_NAME = "rosmx_optional"
25
+ OVIMX_NAME = "ovimx_optional"
26
+ OVIMX2_NAME = "ovimx2_optional"
27
+ OVIMX3_NAME = "ovimx3_optional"
28
+ PRTHE_NAME = "prthe_optional"
29
+
30
+ SUCCESSFULLY_LOADED_LORAS = {}
31
+
32
+ lora_definitions = {
33
+ CAUSVID_NAME: ("joerose/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32", "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"),
34
+ PERSONVID_NAME: ("ovi054/p3r5onVid1000", None),
35
+ ROSMX_NAME: ("ovi054/rosmxVid1500", None),
36
+ OVIMX_NAME: ("ovi054/ovimxVid1750", None),
37
+ OVIMX2_NAME: ("ovi054/ovimxVid2250", None),
38
+ OVIMX3_NAME: ("ovi054/ovimxVid2500", None),
39
+ PRTHE_NAME: ("ovi054/prwthxVid", None)
40
+ }
41
+
42
+ for name, (repo, filename) in lora_definitions.items():
43
+ print(f"Attempting to load LoRA '{name}'...")
44
+ try:
45
+ if filename:
46
+ path = hf_hub_download(repo_id=repo, filename=filename)
47
+ pipe.load_lora_weights(path, adapter_name=name, device_map="auto")
48
+ else:
49
+ pipe.load_lora_weights(repo, adapter_name=name, device_map="auto")
50
+ print(f"✅ LoRA '{name}' loaded successfully.")
51
+ SUCCESSFULLY_LOADED_LORAS[name] = repo
52
+ except Exception as e:
53
+ print(f"⚠️ LoRA '{name}' could not be loaded: {e}")
54
+
55
+ OPTIONAL_LORA_MAP = {
56
+ "ovi054/p3r5onVid1000": PERSONVID_NAME,
57
+ "ovi054/rosmxVid1500": ROSMX_NAME,
58
+ "ovi054/ovimxVid1750": OVIMX_NAME,
59
+ "ovi054/ovimxVid2250": OVIMX2_NAME,
60
+ "ovi054/ovimxVid2500": OVIMX3_NAME,
61
+ "ovi054/prwthxVid": PRTHE_NAME,
62
+ }
63
+ OPTIONAL_LORA_CHOICES = {k: v for k, v in OPTIONAL_LORA_MAP.items() if v in SUCCESSFULLY_LOADED_LORAS}
64
+
65
+ # At startup, disable all adapters. They will be selectively enabled during each run.
66
+ print("Disabling all LoRAs at startup. They will be activated on-demand.")
67
+ pipe.disable_lora()
68
 
69
  print("Initialization complete. Gradio is starting...")
70
 
71
  @spaces.GPU()
72
+ def generate(prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ # Using a try...finally block is robust for state management in apps.
75
  try:
76
+ # --- Step 1: Build the list of ACTIVE adapters and their weights for THIS run ---
77
+
78
+ active_adapters = []
79
+ adapter_weights = []
80
+
81
+ # Always include the base LoRA if it was loaded successfully
82
+ if CAUSVID_NAME in SUCCESSFULLY_LOADED_LORAS:
83
+ active_adapters.append(CAUSVID_NAME)
84
+ adapter_weights.append(1.0)
85
+
86
+ # If an optional LoRA is selected, add it to the list
87
+ if optional_lora_id and optional_lora_id != "None":
88
+ internal_name_to_add = OPTIONAL_LORA_CHOICES.get(optional_lora_id)
89
+ if internal_name_to_add:
90
+ active_adapters.append(internal_name_to_add)
91
+ adapter_weights.append(1.0)
92
+
93
+ # --- Step 2: Apply the adapters and weights for this run using the correct function ---
94
+ if active_adapters:
95
+ print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
96
+ # This is the correct, modern way to set adapters and their weights.
97
+ pipe.set_adapters(active_adapters, adapter_weights=adapter_weights)
98
+ else:
99
+ print("No LoRAs are active for this run.")
100
+ # ensure all are disabled if for some reason none were selected
101
+ pipe.disable_lora()
102
+
103
+ apply_cache_on_pipe(pipe)
104
+
105
+ # --- Step 3: Run inference ---
106
  output = pipe(
107
  prompt=prompt,
108
  negative_prompt=negative_prompt,
 
115
  image = output.frames[0][0]
116
  image = (image * 255).astype(np.uint8)
117
  return Image.fromarray(image)
 
 
 
 
 
 
118
 
119
+ finally:
120
+ print("Disabling LoRAs after run to reset state.")
121
+ pipe.disable_lora()
122
 
123
# --- Gradio Interface ---
# Inputs are positional and must match the generate() signature:
# (prompt, negative_prompt, width, height, num_inference_steps, optional_lora_id).
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
        gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
        gr.Textbox(
            label="Optional LoRA",
        )
    ],
    outputs=gr.Image(label="output"),
)
  )