# Patch notes: add the 'vanilla' sample solver, an "Exclude Single Blocks"
# checkbox and LoRA selection (4 slots + folder) to the WanX-i2v and WanX-t2v
# tabs, and thread the new values through wanx_generate_video and
# wanx_generate_video_batch.
#
# NOTE(review): line breaks were restored from a whitespace-collapsed copy of
# this patch. Indentation below shows relative nesting only and the @@ line
# counts were not recomputed -- re-align each hunk with the target file before
# applying it.

# Update the sample_solver choices to include 'vanilla'
@@ -1035,7 +1035,7 @@ with gr.Blocks(
     with gr.Row():
         wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video")
-        wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc")
+        wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc")
         wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa")
         wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0)
         wanx_fp8 = gr.Checkbox(label="Use FP8", value=True)

# Add exclude_single_blocks checkbox for WanX-i2v
@@ -1035,6 +1035,7 @@ with gr.Blocks(
     with gr.Row():
         wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video")
         wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc")
+        wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False)
         wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa")
         wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0)
         wanx_fp8 = gr.Checkbox(label="Use FP8", value=True)

# Add LoRA support to WanX-i2v tab
@@ -979,7 +979,27 @@ with gr.Blocks(
     )
     wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video")
+    # Add LoRA section for WanX-i2v similar to other tabs
+    wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn")
+    wanx_lora_weights = []
+    wanx_lora_multipliers = []
+    for i in range(4):
+        with gr.Column():
+            wanx_lora_weights.append(gr.Dropdown(
+                label=f"LoRA {i+1}",
+                choices=get_lora_options(),
+                value="None",
+                allow_custom_value=True,
+                interactive=True
+            ))
+            wanx_lora_multipliers.append(gr.Slider(
+                label="Multiplier",
+                minimum=0.0,
+                maximum=2.0,
+                step=0.05,
+                value=1.0
+            ))
+
     with gr.Row():
         wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1)
         wanx_task = gr.Dropdown(
@@ -992,6 +1012,7 @@ with gr.Blocks(
         wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth")
         wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth")
         wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
+        wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora")
         wanx_save_path = gr.Textbox(label="Save Path", value="outputs")

# Update WanX-t2v sample solver choices
@@ -1099,7 +1099,7 @@ with gr.Blocks(
     with gr.Row():
         wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video")
-        wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc")
+        wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc")
         wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa")
         wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, info="Max 39 for 14B model, 29 for 1.3B model")

# Add exclude_single_blocks checkbox for WanX-t2v
@@ -1099,6 +1099,7 @@ with gr.Blocks(
     with gr.Row():
         wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video")
         wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc")
+        wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False)
         wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa")
         wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, info="Max 39 for 14B model, 29 for 1.3B model")

# Add LoRA support to WanX-t2v tab
@@ -1063,7 +1064,27 @@ with gr.Blocks(
     )
     wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video")
+    # Add LoRA section for WanX-t2v
+    wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn")
+    wanx_t2v_lora_weights = []
+    wanx_t2v_lora_multipliers = []
+    for i in range(4):
+        with gr.Column():
+            wanx_t2v_lora_weights.append(gr.Dropdown(
+                label=f"LoRA {i+1}",
+                choices=get_lora_options(),
+                value="None",
+                allow_custom_value=True,
+                interactive=True
+            ))
+            wanx_t2v_lora_multipliers.append(gr.Slider(
+                label="Multiplier",
+                minimum=0.0,
+                maximum=2.0,
+                step=0.05,
+                value=1.0
+            ))
+
     with gr.Row():
         wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1)
         wanx_t2v_task = gr.Dropdown(
@@ -1077,6 +1098,7 @@ with gr.Blocks(
         wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth")
         wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth")
         wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="")
+        wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora")
         wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs")

# Update wanx_generate_video function to include LoRA and exclude_single_blocks
@@ -2051,6 +2073,15 @@ def wanx_generate_video(
     save_path,
     output_type,
     sample_solver,
+    exclude_single_blocks,
     attn_mode,
     block_swap,
     fp8,
-    fp8_t5
+    fp8_t5,
+    lora_folder,
+    lora1="None",
+    lora2="None",
+    lora3="None",
+    lora4="None",
+    lora1_multiplier=1.0,
+    lora2_multiplier=1.0,
+    lora3_multiplier=1.0,
+    lora4_multiplier=1.0
 ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
     """Generate video with WanX model (supports both i2v and t2v)"""
     global stop_event
@@ -2107,6 +2138,20 @@ def wanx_generate_video(
     if fp8_t5:
         command.append("--fp8_t5")
+
+    if exclude_single_blocks:
+        command.append("--exclude_single_blocks")
+
+    # Add LoRA weights and multipliers if provided
+    valid_loras = []
+    for weight, mult in zip([lora1, lora2, lora3, lora4],
+                            [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]):
+        if weight and weight != "None":
+            valid_loras.append((os.path.join(lora_folder, weight), mult))
+    if valid_loras:
+        weights = [weight for weight, _ in valid_loras]
+        multipliers = [str(mult) for _, mult in valid_loras]
+        command.extend(["--lora_weight"] + weights)
+        command.extend(["--lora_multiplier"] + multipliers)
     print(f"Running: {' '.join(command)}")

# Update wanx_generate_video_batch function
# NOTE(review): the second hunk below is labelled as inside
# wanx_generate_video_batch, but its trailing context (`outputs=[...]`,
# `queue=True`) looks like a click-handler argument list -- confirm the hunk
# location against the target file.
@@ -2176,9 +2221,19 @@ def wanx_generate_video_batch(
     save_path,
     output_type,
     sample_solver,
+    exclude_single_blocks,
     attn_mode,
     block_swap,
     fp8,
-    fp8_t5,
+    fp8_t5,
+    lora_folder,
+    lora1="None",
+    lora2="None",
+    lora3="None",
+    lora4="None",
+    lora1_multiplier=1.0,
+    lora2_multiplier=1.0,
+    lora3_multiplier=1.0,
+    lora4_multiplier=1.0,
     batch_size=1,
     input_image=None  # Make input_image optional and place it at the end
 ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
@@ -2201,9 +2256,19 @@ def wanx_generate_video_batch(
         save_path,
         output_type,
         sample_solver,
+        exclude_single_blocks,
         attn_mode,
         block_swap,
         fp8,
-        fp8_t5
+        fp8_t5,
+        lora_folder,
+        lora1,
+        lora2,
+        lora3,
+        lora4,
+        lora1_multiplier,
+        lora2_multiplier,
+        lora3_multiplier,
+        lora4_multiplier
     ),
     outputs=[wanx_output, wanx_batch_progress, wanx_progress_text],
     queue=True

# Update WanX-i2v generate button click handler
# NOTE(review): the refresh handler passes inputs grouped as
# [folder, weights..., multipliers...] but lists outputs interleaved as
# [weight0, multiplier0, weight1, multiplier1, ...]; verify this matches the
# return order of update_lora_dropdowns (not visible here).
@@ -2423,6 +2488,15 @@ wanx_generate_btn.click(
         wanx_save_path,
         wanx_output_type,
         wanx_sample_solver,
+        wanx_exclude_single_blocks,
         wanx_attn_mode,
         wanx_block_swap,
         wanx_fp8,
-        wanx_fp8_t5,
+        wanx_fp8_t5,
+        wanx_lora_folder,
+        *wanx_lora_weights,
+        *wanx_lora_multipliers,
         wanx_batch_size,
         wanx_input  # Include the image input for this tab
     ],
     outputs=[wanx_output, wanx_batch_progress, wanx_progress_text],
     queue=True
 )
+
+# Add refresh button handler for WanX-i2v tab
+wanx_refresh_outputs = []
+for i in range(4):
+    wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]])
+
+wanx_refresh_btn.click(
+    fn=update_lora_dropdowns,
+    inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers,
+    outputs=wanx_refresh_outputs
+)

# Update WanX-t2v generate button click handler
# NOTE(review): same inputs/outputs ordering question as the i2v refresh
# handler -- outputs are interleaved per slot, inputs are grouped; verify
# against update_lora_dropdowns.
@@ -2470,9 +2544,19 @@ wanx_t2v_generate_btn.click(
         wanx_t2v_save_path,
         wanx_t2v_output_type,
         wanx_t2v_sample_solver,
+        wanx_t2v_exclude_single_blocks,
         wanx_t2v_attn_mode,
         wanx_t2v_block_swap,
         wanx_t2v_fp8,
-        wanx_t2v_fp8_t5,
+        wanx_t2v_fp8_t5,
+        wanx_t2v_lora_folder,
+        *wanx_t2v_lora_weights,
+        *wanx_t2v_lora_multipliers,
         wanx_t2v_batch_size,
         # input_image is now optional and not included here
     ],
     outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text],
     queue=True
 )
+
+# Add refresh button handler for WanX-t2v tab
+wanx_t2v_refresh_outputs = []
+for i in range(4):
+    wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]])
+
+wanx_t2v_refresh_btn.click(
+    fn=update_lora_dropdowns,
+    inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers,
+    outputs=wanx_t2v_refresh_outputs
+)