import gradio as gr from gradio import update as gr_update import subprocess import threading import time import re import os import random import tiktoken import sys import ffmpeg from typing import List, Tuple, Optional, Generator, Dict import json from gradio import themes from gradio.themes.utils import colors import subprocess from PIL import Image import math import cv2 # Add global stop event stop_event = threading.Event() def get_dit_models(dit_folder: str) -> List[str]: """Get list of available DiT models in the specified folder""" if not os.path.exists(dit_folder): return ["mp_rank_00_model_states.pt"] models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] models.sort(key=str.lower) return models if models else ["mp_rank_00_model_states.pt"] def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: """Update both DiT and LoRA dropdowns""" # Get model lists dit_models = get_dit_models(dit_folder) lora_choices = get_lora_options(lora_folder) # Current values processing dit_value = current_values[0] if dit_value not in dit_models: dit_value = dit_models[0] if dit_models else None weights = current_values[1:5] multipliers = current_values[5:9] results = [gr.update(choices=dit_models, value=dit_value)] # Add LoRA updates for i in range(4): weight = weights[i] if i < len(weights) else "None" multiplier = multipliers[i] if i < len(multipliers) else 1.0 if weight not in lora_choices: weight = "None" results.extend([ gr.update(choices=lora_choices, value=weight), gr.update(value=multiplier) ]) return results def extract_video_metadata(video_path: str) -> Dict: """Extract metadata from video file using ffprobe.""" cmd = [ 'ffprobe', '-v', 'quiet', '-print_format', 'json', '-show_format', video_path ] try: result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) metadata = json.loads(result.stdout.decode('utf-8')) if 'format' in metadata and 'tags' in metadata['format']: comment = metadata['format']['tags'].get('comment', '{}') return json.loads(comment) return {} except Exception as e: print(f"Metadata extraction failed: {str(e)}") return {} def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: """Map metadata parameters to Gradio components for different tabs""" mapping = { 'common': { 'prompt': ('prompt', 'v2v_prompt'), 'width': ('width', 'v2v_width'), 'height': ('height', 'v2v_height'), 'batch_size': ('batch_size', 'v2v_batch_size'), 'video_length': ('video_length', 'v2v_video_length'), 'fps': ('fps', 'v2v_fps'), 'infer_steps': ('infer_steps', 'v2v_infer_steps'), 'seed': ('seed', 'v2v_seed'), 'model': ('model', 'v2v_model'), 'vae': ('vae', 'v2v_vae'), 'te1': ('te1', 'v2v_te1'), 'te2': ('te2', 'v2v_te2'), 'save_path': ('save_path', 'v2v_save_path'), 'flow_shift': ('flow_shift', 'v2v_flow_shift'), 'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), 'output_type': ('output_type', 'v2v_output_type'), 'attn_mode': ('attn_mode', 'v2v_attn_mode'), 'block_swap': ('block_swap', 'v2v_block_swap') }, 'lora': { 'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], 'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] } } results = {} for param, value in metadata.items(): # Handle common parameters if param in mapping['common']: target = mapping['common'][param][0 if target_tab == 't2v' else 1] results[target] = value # Handle LoRA parameters if param == 'lora_weights': for i, weight in 
enumerate(value[:4]): target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] results[target] = weight if param == 'lora_multipliers': for i, mult in enumerate(value[:4]): target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] results[target] = float(mult) return results def add_metadata_to_video(video_path: str, parameters: dict) -> None: """Add generation parameters to video metadata using ffmpeg.""" import json import subprocess # Convert parameters to JSON string params_json = json.dumps(parameters, indent=2) # Temporary output path temp_path = video_path.replace(".mp4", "_temp.mp4") # FFmpeg command to add metadata without re-encoding cmd = [ 'ffmpeg', '-i', video_path, '-metadata', f'comment={params_json}', '-codec', 'copy', temp_path ] try: # Execute FFmpeg command subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Replace original file with the metadata-enhanced version os.replace(temp_path, video_path) except subprocess.CalledProcessError as e: print(f"Failed to add metadata: {e.stderr.decode()}") if os.path.exists(temp_path): os.remove(temp_path) except Exception as e: print(f"Error: {str(e)}") def count_prompt_tokens(prompt: str) -> int: enc = tiktoken.get_encoding("cl100k_base") tokens = enc.encode(prompt) return len(tokens) def get_lora_options(lora_folder: str = "lora") -> List[str]: if not os.path.exists(lora_folder): return ["None"] lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] lora_files.sort(key=str.lower) return ["None"] + lora_files def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: new_choices = get_lora_options(lora_folder) weights = current_values[:4] multipliers = current_values[4:8] results = [] for i in range(4): weight = weights[i] if i < len(weights) else "None" multiplier = multipliers[i] if i < len(multipliers) else 1.0 if weight not in new_choices: weight = "None" results.extend([ gr.update(choices=new_choices, value=weight), gr.update(value=multiplier) ]) return results def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: """Transfer selected video and prompt to Video2Video tab""" if not gallery or evt.index >= len(gallery): return None, "", selected_index.value selected_item = gallery[evt.index] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] # Update the selected index selected_index.value = evt.index return str(video_path), prompt, evt.index def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: """Send the currently selected video to V2V tab""" if not gallery or selected_index.value is None or selected_index.value >= len(gallery): return None, "" selected_item = gallery[selected_index.value] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] return 
str(video_path), prompt def clear_cuda_cache(): """Clear CUDA cache if available""" import torch if torch.cuda.is_available(): torch.cuda.empty_cache() # Optional: synchronize to ensure cache is cleared torch.cuda.synchronize() def process_single_video( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, lora1: str = "", lora2: str = "", lora3: str = "", lora4: str = "", lora1_multiplier: float = 1.0, lora2_multiplier: float = 1.0, lora3_multiplier: float = 1.0, lora4_multiplier: float = 1.0, video_path: Optional[str] = None, image_path: Optional[str] = None, strength: Optional[float] = None, negative_prompt: Optional[str] = None, embedded_cfg_scale: Optional[float] = None, split_uncond: Optional[bool] = None, guidance_scale: Optional[float] = None, use_fp8: bool = True ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate a single video with the given parameters""" global stop_event if stop_event.is_set(): yield [], "", "" return # Determine if this is a SkyReels model and what type is_skyreels = "skyreels" in model.lower() is_skyreels_i2v = is_skyreels and "i2v" in model.lower() is_skyreels_t2v = is_skyreels and "t2v" in model.lower() if is_skyreels: # Force certain parameters for SkyReels if negative_prompt is None: negative_prompt = "" if embedded_cfg_scale is None: embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels if split_uncond is None: split_uncond = True if guidance_scale is None: guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided # Determine the input channels based on model type if is_skyreels_i2v: dit_in_channels = 32 # SkyReels I2V uses 32 channels else: dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) else: dit_in_channels = 16 # Regular Hunyuan models use 16 channels embedded_cfg_scale = cfg_scale if os.path.isabs(model): model_path = model else: model_path = os.path.normpath(os.path.join(dit_folder, model)) env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" env["BATCH_RUN_ID"] = f"{time.time()}" if seed == -1: current_seed = random.randint(0, 2**32 - 1) else: batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) if batch_size > 1: # Only modify seed for batch generation current_seed = (seed + batch_id * 100003) % (2**32) else: current_seed = seed clear_cuda_cache() command = [ sys.executable, "hv_generate_video.py", "--dit", model_path, "--vae", vae, "--text_encoder1", te1, "--text_encoder2", te2, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--embedded_cfg_scale", str(cfg_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--fp8_llm", "--vae_chunk_size", "32", "--vae_spatial_tile_sample_min_size", "128" ] if use_fp8: command.append("--fp8") # Add negative prompt and embedded cfg scale for SkyReels if is_skyreels: command.extend(["--dit_in_channels", str(dit_in_channels)]) command.extend(["--guidance_scale", str(guidance_scale)]) if negative_prompt: 
command.extend(["--negative_prompt", negative_prompt]) if split_uncond: command.append("--split_uncond") # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip([lora1, lora2, lora3, lora4], [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] command.extend(["--lora_weight"] + weights) command.extend(["--lora_multiplier"] + multipliers) if exclude_single_blocks: command.append("--exclude_single_blocks") if use_split_attn: command.append("--split_attn") # Handle input paths if video_path: command.extend(["--video_path", video_path]) if strength is not None: command.extend(["--strength", str(strength)]) elif image_path: command.extend(["--image_path", image_path]) # Only add strength parameter for non-SkyReels I2V models # SkyReels I2V doesn't use strength parameter for image-to-video generation if strength is not None and not is_skyreels_i2v: command.extend(["--strength", str(strength)]) print(f"{command}") p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) videos = [] while True: if stop_event.is_set(): p.terminate() p.wait() yield [], "", "Generation stopped by user." return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) # Collect parameters for metadata parameters = { "prompt": prompt, "width": width, "height": height, "video_length": video_length, "fps": fps, "infer_steps": infer_steps, "seed": current_seed, "model": model, "vae": vae, "te1": te1, "te2": te2, "save_path": save_path, "flow_shift": flow_shift, "cfg_scale": cfg_scale, "output_type": output_type, "attn_mode": attn_mode, "block_swap": block_swap, "lora_weights": [lora1, lora2, lora3, lora4], "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], "input_video": video_path if video_path else None, "input_image": image_path if image_path else None, "strength": strength, "negative_prompt": negative_prompt if is_skyreels else None, "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None } add_metadata_to_video(video_path, parameters) videos.append((str(video_path), f"Seed: {current_seed}")) yield videos, f"Completed (seed: {current_seed})", "" # The issue is in the process_batch function, in the section that handles different input types # Here's the corrected version of that section: def process_batch( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: 
int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, *args ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Process a batch of videos using Gradio's queue""" global stop_event stop_event.clear() all_videos = [] progress_text = "Starting generation..." yield [], "Preparing...", progress_text # Extract additional arguments num_lora_weights = 4 lora_weights = args[:num_lora_weights] lora_multipliers = args[num_lora_weights:num_lora_weights*2] extra_args = args[num_lora_weights*2:] # Determine if this is a SkyReels model and what type is_skyreels = "skyreels" in model.lower() is_skyreels_i2v = is_skyreels and "i2v" in model.lower() is_skyreels_t2v = is_skyreels and "t2v" in model.lower() # Handle input paths and additional parameters input_path = extra_args[0] if extra_args else None strength = float(extra_args[1]) if len(extra_args) > 1 else None # Get use_fp8 flag (it should be the last parameter) use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True # Get SkyReels specific parameters if applicable if is_skyreels: # Always set embedded_cfg_scale to 1.0 for SkyReels models embedded_cfg_scale = 1.0 negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" # Use cfg_scale for guidance_scale parameter guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale split_uncond = True if len(extra_args) > 4 and extra_args[4] else False else: negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None guidance_scale = cfg_scale embedded_cfg_scale = cfg_scale split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None for i in range(batch_size): if stop_event.is_set(): break batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Handle different input types video_path = None image_path = None if input_path: # Check if it's an image file (common image extensions) is_image = False lower_path = input_path.lower() image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') is_image = any(lower_path.endswith(ext) for ext in image_extensions) # Only use image_path for SkyReels I2V models and actual image files if is_skyreels_i2v and is_image: image_path = input_path else: video_path = input_path # Prepare arguments for process_single_video single_video_args = [ prompt, width, height, batch_size, video_length, fps, infer_steps, seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder ] single_video_args.extend(lora_weights) single_video_args.extend(lora_multipliers) single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) for videos, status, progress in process_single_video(*single_video_args): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress yield all_videos, "Batch complete", "" def update_wanx_image_dimensions(image): """Update dimensions from uploaded image""" if image is None: return "", gr.update(value=832), gr.update(value=480) img = Image.open(image) w, h = img.size w = (w // 32) * 32 h = (h // 32) * 32 return f"{w}x{h}", w, h def calculate_wanx_width(height, original_dims): """Calculate width based on height maintaining aspect ratio""" if not original_dims: return gr.update() orig_w, orig_h = 
map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_width = math.floor((height * aspect_ratio) / 32) * 32 return gr.update(value=new_width) def calculate_wanx_height(width, original_dims): """Calculate height based on width maintaining aspect ratio""" if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_height = math.floor((width / aspect_ratio) / 32) * 32 return gr.update(value=new_height) def update_wanx_from_scale(scale, original_dims): """Update dimensions based on scale percentage""" if not original_dims: return gr.update(), gr.update() orig_w, orig_h = map(int, original_dims.split('x')) new_w = math.floor((orig_w * scale / 100) / 32) * 32 new_h = math.floor((orig_h * scale / 100) / 32) * 32 return gr.update(value=new_w), gr.update(value=new_h) def recommend_wanx_flow_shift(width, height): """Get recommended flow shift value based on dimensions""" recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 return gr.update(value=recommended_shift) def handle_wanx_gallery_select(evt: gr.SelectData) -> int: """Track selected index when gallery item is clicked""" return evt.index def wanx_generate_video( prompt, negative_prompt, input_image, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, attn_mode, block_swap, fp8, fp8_t5 ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate video with WanX model (supports both i2v and t2v)""" global stop_event if stop_event.is_set(): yield [], "", "" return if seed == -1: current_seed = random.randint(0, 2**32 - 1) else: current_seed = seed # Check if we need input image (required for i2v, not for t2v) if "i2v" in task and not input_image: yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" return # Prepare environment env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" clear_cuda_cache() command = [ sys.executable, "wan_generate_video.py", "--task", task, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--guidance_scale", str(guidance_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--dit", dit_path, "--vae", vae_path, "--t5", t5_path, "--sample_solver", sample_solver ] # Add image path only for i2v task and if input image is provided if "i2v" in task and input_image: command.extend(["--image_path", input_image]) command.extend(["--clip", clip_path]) # CLIP is only needed for i2v if negative_prompt: command.extend(["--negative_prompt", negative_prompt]) if fp8: command.append("--fp8") if fp8_t5: command.append("--fp8_t5") print(f"Running: {' '.join(command)}") p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) videos = [] while True: if stop_event.is_set(): p.terminate() p.wait() yield [], "", "Generation stopped by user." 
return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) # Collect parameters for metadata parameters = { "prompt": prompt, "width": width, "height": height, "video_length": video_length, "fps": fps, "infer_steps": infer_steps, "seed": current_seed, "task": task, "flow_shift": flow_shift, "guidance_scale": guidance_scale, "output_type": output_type, "attn_mode": attn_mode, "block_swap": block_swap, "input_image": input_image if "i2v" in task else None } add_metadata_to_video(video_path, parameters) videos.append((str(video_path), f"Seed: {current_seed}")) yield videos, f"Completed (seed: {current_seed})", "" def send_wanx_to_v2v( gallery: list, prompt: str, selected_index: int, width: int, height: int, video_length: int, fps: int, infer_steps: int, seed: int, flow_shift: float, guidance_scale: float, negative_prompt: str ) -> Tuple: """Send the selected WanX video to Video2Video tab""" if not gallery or selected_index is None or selected_index >= len(gallery): return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) selected_item = gallery[selected_index] if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item if isinstance(video_path, tuple): video_path = video_path[0] return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) def wanx_generate_video_batch( prompt, negative_prompt, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, attn_mode, block_swap, fp8, fp8_t5, batch_size=1, input_image=None, # Optional for i2v lora_folder=None, *args ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate videos with WanX with support for batches and LoRA""" global stop_event stop_event.clear() all_videos = [] progress_text = "Starting generation..." 
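    # NOTE (documenting the convention used by the unpacking just below): the *args tail
    # is positional — the four LoRA weight filenames come first, then their four
    # multipliers, then optionally the exclude_single_blocks flag. The UI wiring that
    # calls this handler is expected to pass values in that order.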
yield [], "Preparing...", progress_text # Extract LoRA parameters from args num_loras = 4 # Fixed number of LoRA inputs lora_weights = args[:num_loras] lora_multipliers = args[num_loras:num_loras*2] exclude_single_blocks = args[num_loras*2] if len(args) > num_loras*2 else False # Process each item in the batch for i in range(batch_size): if stop_event.is_set(): yield all_videos, "Generation stopped by user", "" return # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Prepare command env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" command = [ sys.executable, "wan_generate_video.py", "--task", task, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--guidance_scale", str(guidance_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--dit", dit_path, "--vae", vae_path, "--t5", t5_path, "--sample_solver", sample_solver ] # Add image path if provided (for i2v) if input_image and "i2v" in task: command.extend(["--image_path", input_image]) command.extend(["--clip", clip_path]) # CLIP is needed for i2v # Add negative prompt if provided if negative_prompt: command.extend(["--negative_prompt", negative_prompt]) # Add block swap if provided if block_swap > 0: command.extend(["--blocks_to_swap", str(block_swap)]) # Add fp8 flags if enabled if fp8: command.append("--fp8") if fp8_t5: command.append("--fp8_t5") # Add LoRA parameters valid_loras = [] for j, (weight, mult) in enumerate(zip(lora_weights, lora_multipliers)): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), float(mult))) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] command.extend(["--lora_weight"] + weights) command.extend(["--lora_multiplier"] + multipliers) # Add LoRA options if exclude_single_blocks: command.append("--exclude_single_blocks") print(f"Running: {' '.join(command)}") # Execute command p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) videos = [] # Process output while True: if stop_event.is_set(): p.terminate() p.wait() yield all_videos, "Generation stopped by user", "" return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield all_videos.copy(), f"Batch {i+1}/{batch_size}: Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() # Clean CUDA cache clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_video_files = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_video_files if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) videos.append((str(video_path), f"Seed: {current_seed}")) all_videos.extend(videos) 
yield all_videos, "Batch complete", "" def update_wanx_t2v_dimensions(size): """Update width and height based on selected size""" width, height = map(int, size.split('*')) return gr.update(value=width), gr.update(value=height) def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: """Track selected index when gallery item is clicked""" return evt.index def send_wanx_t2v_to_v2v( gallery, prompt, selected_index, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt ) -> Tuple: """Send the selected WanX T2V video to Video2Video tab""" if not gallery or selected_index is None or selected_index >= len(gallery): return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) selected_item = gallery[selected_index] if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item if isinstance(video_path, tuple): video_path = video_path[0] return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) # UI setup with gr.Blocks( theme=themes.Default( primary_hue=colors.Color( name="custom", c50="#E6F0FF", c100="#CCE0FF", c200="#99C1FF", c300="#66A3FF", c400="#3384FF", c500="#0060df", # This is your main color c600="#0052C2", c700="#003D91", c800="#002961", c900="#001430", c950="#000A18" ) ), css=""" .gallery-item:first-child { border: 2px solid #4CAF50 !important; } .gallery-item:first-child:hover { border-color: #45a049 !important; } .green-btn { background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; color: white !important; border: none !important; } .green-btn:hover { background: linear-gradient(to bottom right, #27ae60, #219651) !important; } .refresh-btn { max-width: 40px !important; min-width: 40px !important; height: 40px !important; border-radius: 50% !important; padding: 0 !important; display: flex !important; align-items: center !important; justify-content: center !important; } """, ) as demo: # Add state for tracking selected video indices in both tabs selected_index = gr.State(value=None) # For Text to Video v2v_selected_index = gr.State(value=None) # For Video to Video params_state = gr.State() #New addition i2v_selected_index = gr.State(value=None) skyreels_selected_index = gr.State(value=None) demo.load(None, None, None, js=""" () => { document.title = 'H1111'; function updateTitle(text) { if (text && text.trim()) { const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); if (progressMatch) { const percentage = progressMatch[1]; const timeRemaining = progressMatch[2]; document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; } } } setTimeout(() => { const progressElements = document.querySelectorAll('textarea.scroll-hide'); progressElements.forEach(element => { if (element) { new MutationObserver(() => { updateTitle(element.value); }).observe(element, { attributes: true, childList: true, characterData: true }); } }); }, 1000); } """) with gr.Tabs() as tabs: # Text to Video Tab with gr.Tab(id=1, label="Text to Video"): with gr.Row(): with gr.Column(scale=4): prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) with gr.Column(scale=1): token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) batch_size = gr.Number(label="Batch Count", value=1, minimum=1, 
step=1) with gr.Column(scale=2): batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): generate_btn = gr.Button("Generate Video", elem_classes="green-btn") stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") with gr.Column(): with gr.Row(): video_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") with gr.Row(): refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") lora_weights = [] lora_multipliers = [] for i in range(4): with gr.Column(): lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) seed = gr.Number(label="Seed (use -1 for random)", value=-1) dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): lora_folder = gr.Textbox(label="LoRA Folder", value="lora") output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) #Image to Video Tab with gr.Tab(label="Image to Video") as i2v_tab: with gr.Row(): with gr.Column(scale=4): i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) with gr.Column(scale=1): i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): i2v_batch_progress = gr.Textbox(label="", visible=True, 
elem_id="batch_progress") i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") i2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): i2v_input = gr.Image(label="Input Image", type="filepath") i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") # Scale slider as percentage scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height inputs with gr.Row(): width = gr.Number(label="New Width", value=544, step=16) calc_height_btn = gr.Button("→") calc_width_btn = gr.Button("←") height = gr.Number(label="New Height", value=544, step=16) i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) with gr.Column(): i2v_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") # Add LoRA section for Image2Video i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") i2v_lora_weights = [] i2v_lora_multipliers = [] for i in range(4): with gr.Column(): i2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) i2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") i2v_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") i2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) # Video to Video Tab with gr.Tab(id=2, label="Video to Video") as v2v_tab: with gr.Row(): with gr.Column(scale=4): v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) 
v2v_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt (for SkyReels models)", value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", lines=3 ) with gr.Column(scale=1): v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") v2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): v2v_input = gr.Video(label="Input Video", format="mp4") v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and Height Inputs with gr.Row(): v2v_width = gr.Number(label="New Width", value=544, step=16) v2v_calc_height_btn = gr.Button("→") v2v_calc_width_btn = gr.Button("←") v2v_height = gr.Number(label="New Height", value=544, step=16) v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) with gr.Column(): v2v_output = gr.Gallery( label="Generated Videos", columns=[1], rows=[1], object_fit="contain", height="auto" ) v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") v2v_lora_weights = [] v2v_lora_multipliers = [] for i in range(4): with gr.Column(): v2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) v2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") v2v_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") v2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", 
"xformers", "torch"], label="Attention Mode", value="sdpa") v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) with gr.Tab(label="SkyReels-i2v") as skyreels_tab: with gr.Row(): with gr.Column(scale=4): skyreels_prompt = gr.Textbox( scale=3, label="Enter your prompt", value="A person walking on a beach at sunset", lines=5 ) skyreels_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", lines=3 ) with gr.Column(scale=1): skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") # Scale slider as percentage skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height inputs with gr.Row(): skyreels_width = gr.Number(label="New Width", value=544, step=16) skyreels_calc_height_btn = gr.Button("→") skyreels_calc_width_btn = gr.Button("←") skyreels_height = gr.Number(label="New Height", value=544, step=16) skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) with gr.Column(): skyreels_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") # Add LoRA section for SKYREELS skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") skyreels_lora_weights = [] skyreels_lora_multipliers = [] for i in range(4): with gr.Column(): skyreels_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) skyreels_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", 
value="hunyuan") skyreels_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("skyreels"), value="skyreels_hunyuan_i2v_bf16.safetensors", allow_custom_value=True, interactive=True ) skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) # WanX Image to Video Tab with gr.Tab(label="WanX-i2v") as wanx_i2v_tab: with gr.Row(): with gr.Column(scale=4): wanx_prompt = gr.Textbox( scale=3, label="Enter your prompt", value="A person walking on a beach at sunset", lines=5 ) wanx_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="", lines=3, info="Leave empty to use default negative prompt" ) with gr.Column(scale=1): wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") wanx_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): wanx_input = gr.Image(label="Input Image", type="filepath") wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height display with gr.Row(): wanx_width = gr.Number(label="Width", value=832, interactive=True) wanx_calc_height_btn = gr.Button("→") wanx_calc_width_btn = gr.Button("←") wanx_height = gr.Number(label="Height", value=480, interactive=True) wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, info="Recommended: 3.0 for 480p, 5.0 for others") wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) with gr.Column(): wanx_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") with gr.Row(): wanx_refresh_btn = gr.Button("🔄", 
elem_classes="refresh-btn") wanx_lora_weights = [] wanx_lora_multipliers = [] for i in range(4): with gr.Column(): wanx_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) wanx_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) wanx_task = gr.Dropdown( label="Task", choices=["i2v-14B"], value="i2v-14B", info="Currently only i2v-14B is supported" ) wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") wanx_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) #WanX-t2v Tab # WanX Text to Video Tab with gr.Tab(label="WanX-t2v") as wanx_t2v_tab: with gr.Row(): with gr.Column(scale=4): wanx_t2v_prompt = gr.Textbox( scale=3, label="Enter your prompt", value="A person walking on a beach at sunset", lines=5 ) wanx_t2v_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="", lines=3, info="Leave empty to use default negative prompt" ) with gr.Column(scale=1): wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): with gr.Row(): wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, info="Recommended: 3.0 for I2V with 480p, 5.0 for others") wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, 
label="Guidance Scale", value=5.0) with gr.Column(): wanx_t2v_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") with gr.Row(): wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") wanx_t2v_lora_weights = [] wanx_t2v_lora_multipliers = [] for i in range(4): with gr.Column(): wanx_t2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) wanx_t2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) wanx_t2v_task = gr.Dropdown( label="Task", choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], value="t2v-14B", info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" ) wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, info="Max 39 for 14B model, 29 for 1.3B model") wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) #Video Info Tab with gr.Tab("Video Info") as video_info_tab: with gr.Row(): video_input = gr.Video(label="Upload Video", interactive=True) metadata_output = gr.JSON(label="Generation Parameters") with gr.Row(): send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") with gr.Row(): status = gr.Textbox(label="Status", interactive=False) #Merge Model's tab with gr.Tab("Convert LoRA") as convert_lora_tab: def suggest_output_name(file_obj) -> str: """Generate suggested output name from input file""" if not file_obj: return "" # Get input filename without extension and add MUSUBI base_name = os.path.splitext(os.path.basename(file_obj.name))[0] return f"{base_name}_MUSUBI" def convert_lora(input_file, output_name: str, target_format: str) -> str: """Convert LoRA file to specified format""" try: if not input_file: return "Error: No input file selected" # Ensure output directory exists os.makedirs("lora", exist_ok=True) # Construct output path output_path = os.path.join("lora", f"{output_name}.safetensors") # Build command cmd = [ sys.executable, "convert_lora.py", "--input", input_file.name, "--output", output_path, "--target", target_format ] print(f"Converting {input_file.name} to {output_path}") # Execute conversion 
result = subprocess.run( cmd, capture_output=True, text=True, check=True ) if os.path.exists(output_path): return f"Successfully converted LoRA to {output_path}" else: return "Error: Output file not created" except subprocess.CalledProcessError as e: return f"Error during conversion: {e.stderr}" except Exception as e: return f"Error: {str(e)}" with gr.Row(): input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") format_radio = gr.Radio( choices=["default", "other"], value="default", label="Target Format", info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" ) with gr.Row(): convert_btn = gr.Button("Convert LoRA", variant="primary") status_output = gr.Textbox(label="Status", interactive=False) # Automatically update output name when file is selected input_file.change( fn=suggest_output_name, inputs=[input_file], outputs=[output_name] ) # Handle conversion convert_btn.click( fn=convert_lora, inputs=[input_file, output_name, format_radio], outputs=status_output ) with gr.Tab("Model Merging") as model_merge_tab: with gr.Row(): with gr.Column(): # Model selection dit_model = gr.Dropdown( label="Base DiT Model", choices=["mp_rank_00_model_states.pt"], value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") with gr.Row(): with gr.Column(): # Output model name output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) merge_btn = gr.Button("Merge Models", variant="primary") merge_status = gr.Textbox(label="Status", interactive=False) with gr.Row(): # LoRA selection section (similar to Text2Video) merge_lora_weights = [] merge_lora_multipliers = [] for i in range(4): with gr.Column(): merge_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) merge_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") #text to video def change_to_tab_one(): return gr.Tabs(selected=1) #This will navigate #video to video def change_to_tab_two(): return gr.Tabs(selected=2) #This will navigate def change_to_skyreels_tab(): return gr.Tabs(selected=3) #SKYREELS TAB!!! 
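    # The tab-switch helpers above return gr.Tabs(selected=...) and are chained with
    # .then() after the various "Send Selected to ..." buttons so the UI jumps to the
    # target tab once values are transferred. gr.Tabs(selected=...) matches a Tab's id,
    # so the target tab needs an explicit id (id=1 and id=2 are declared above).
    #
    # Sketch of how the Video Info tab's saved metadata can be mapped back onto a tab's
    # components. load_params_for_tab is an illustrative helper only (the actual click
    # wiring for the Video Info buttons is not shown in this section):
    def load_params_for_tab(video_file: str, target_tab: str = "t2v") -> dict:
        # Read the JSON blob written into the mp4 comment tag by add_metadata_to_video()
        metadata = extract_video_metadata(video_file)
        # Rename each saved parameter to the matching component name for 't2v' or 'v2v'
        return create_parameter_transfer_map(metadata, target_tab)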
    # Add state management for dimensions
    def sync_skyreels_dimensions(width, height):
        return gr.update(value=width), gr.update(value=height)

    # Add this function to update the LoRA dropdowns in the SKYREELS tab
    def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]:
        new_choices = get_lora_options(lora_folder)
        weights = current_values[:4]
        multipliers = current_values[4:8]
        results = []
        for i in range(4):
            weight = weights[i] if i < len(weights) else "None"
            multiplier = multipliers[i] if i < len(multipliers) else 1.0
            if weight not in new_choices:
                weight = "None"
            results.extend([
                gr.update(choices=new_choices, value=weight),
                gr.update(value=multiplier)
            ])
        return results

    # Add this function to update the models dropdown in the SKYREELS tab
    def update_skyreels_model_dropdown(dit_folder: str) -> Dict:
        models = get_dit_models(dit_folder)
        return gr.update(choices=models, value=models[0] if models else None)

    # Add event handler for model dropdown refresh
    skyreels_dit_folder.change(
        fn=update_skyreels_model_dropdown,
        inputs=[skyreels_dit_folder],
        outputs=[skyreels_model]
    )

    # Add handlers for the refresh button
    skyreels_refresh_btn.click(
        fn=update_skyreels_lora_dropdowns,
        inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers,
        outputs=[comp for i in range(4) for comp in (skyreels_lora_weights[i], skyreels_lora_multipliers[i])]
    )

    # Skyreels dimension handling
    def calculate_skyreels_width(height, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_width = math.floor((height * aspect_ratio) / 16) * 16
        return gr.update(value=new_width)

    def calculate_skyreels_height(width, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_height = math.floor((width / aspect_ratio) / 16) * 16
        return gr.update(value=new_height)

    def update_skyreels_from_scale(scale, original_dims):
        if not original_dims:
            return gr.update(), gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        new_w = math.floor((orig_w * scale / 100) / 16) * 16
        new_h = math.floor((orig_h * scale / 100) / 16) * 16
        return gr.update(value=new_w), gr.update(value=new_h)

    def update_skyreels_dimensions(image):
        if image is None:
            return "", gr.update(value=544), gr.update(value=544)
        img = Image.open(image)
        w, h = img.size
        w = (w // 16) * 16
        h = (h // 16) * 16
        return f"{w}x{h}", w, h

    def handle_skyreels_gallery_select(evt: gr.SelectData) -> int:
        return evt.index

    def send_skyreels_to_v2v(
        gallery: list, prompt: str, selected_index: int,
        width: int, height: int, video_length: int, fps: int, infer_steps: int,
        seed: int, flow_shift: float, cfg_scale: float,
        lora1: str, lora2: str, lora3: str, lora4: str,
        lora1_multiplier: float, lora2_multiplier: float,
        lora3_multiplier: float, lora4_multiplier: float,
        negative_prompt: str = ""  # Add this parameter
    ) -> Tuple:
        if not gallery or selected_index is None or selected_index >= len(gallery):
            return (None, "", width, height, video_length, fps, infer_steps, seed,
                    flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
                    lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
                    negative_prompt)  # Add negative_prompt to return
        selected_item = gallery[selected_index]
        if isinstance(selected_item, dict):
            video_path = selected_item.get("name", selected_item.get("data", None))
        elif isinstance(selected_item, (tuple, list)):
            video_path = selected_item[0]
        else:
            video_path = selected_item
        if isinstance(video_path, tuple):
            video_path = video_path[0]
        return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed,
                flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
                lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
                negative_prompt)  # Add negative_prompt to return

    # Add event handlers for the SKYREELS tab
    skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter)
    skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

    # Image input handling
    skyreels_input.change(
        fn=update_skyreels_dimensions,
        inputs=[skyreels_input],
        outputs=[skyreels_original_dims, skyreels_width, skyreels_height]
    )
    skyreels_scale_slider.change(
        fn=update_skyreels_from_scale,
        inputs=[skyreels_scale_slider, skyreels_original_dims],
        outputs=[skyreels_width, skyreels_height]
    )
    skyreels_calc_width_btn.click(
        fn=calculate_skyreels_width,
        inputs=[skyreels_height, skyreels_original_dims],
        outputs=[skyreels_width]
    )
    skyreels_calc_height_btn.click(
        fn=calculate_skyreels_height,
        inputs=[skyreels_width, skyreels_original_dims],
        outputs=[skyreels_height]
    )

    # SKYREELS tab generator button handler
    skyreels_generate_btn.click(
        fn=process_batch,
        inputs=[
            skyreels_prompt, skyreels_width, skyreels_height, skyreels_batch_size,
            skyreels_video_length, skyreels_fps, skyreels_infer_steps, skyreels_seed,
            skyreels_dit_folder, skyreels_model, skyreels_vae, skyreels_te1, skyreels_te2,
            skyreels_save_path, skyreels_flow_shift, skyreels_embedded_cfg_scale,
            skyreels_output_type, skyreels_attn_mode, skyreels_block_swap,
            skyreels_exclude_single_blocks, skyreels_use_split_attn, skyreels_lora_folder,
            *skyreels_lora_weights, *skyreels_lora_multipliers,
            skyreels_input, skyreels_strength, skyreels_negative_prompt,
            skyreels_guidance_scale, skyreels_split_uncond, skyreels_use_fp8
        ],
        outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[skyreels_batch_size],
        outputs=skyreels_selected_index
    )

    # Gallery selection handling
    skyreels_output.select(
        fn=handle_skyreels_gallery_select,
        outputs=skyreels_selected_index
    )

    # Send to Video2Video handler
    skyreels_send_to_v2v_btn.click(
        fn=send_skyreels_to_v2v,
        inputs=[
            skyreels_output, skyreels_prompt, skyreels_selected_index,
            skyreels_width, skyreels_height, skyreels_video_length, skyreels_fps,
            skyreels_infer_steps, skyreels_seed, skyreels_flow_shift, skyreels_guidance_scale
        ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt],  # This is ok because skyreels_negative_prompt is a Gradio component
        outputs=[
            v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale
        ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt]
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    # Refresh button handler
    skyreels_refresh_outputs = [skyreels_model]
    for i in range(4):
        skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]])
    skyreels_refresh_btn.click(
        fn=update_dit_and_lora_dropdowns,
        inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers,
        outputs=skyreels_refresh_outputs
    )

    # Add skyreels_selected_index to the initial states at the beginning of the script
    skyreels_selected_index = gr.State(value=None)  # Add this with other state declarations
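    # Note on the SkyReels -> Video2Video transfer above: send_skyreels_to_v2v takes 20
    # positional inputs (gallery, prompt, selected_index, eight scalar settings, four LoRA
    # weights, four multipliers, negative_prompt) and returns 19 values, so the inputs and
    # outputs lists must keep exactly that order for the right V2V components to be filled.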
    def calculate_v2v_width(height, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_width = math.floor((height * aspect_ratio) / 16) * 16  # Ensure divisible by 16
        return gr.update(value=new_width)

    def calculate_v2v_height(width, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_height = math.floor((width / aspect_ratio) / 16) * 16  # Ensure divisible by 16
        return gr.update(value=new_height)

    def update_v2v_from_scale(scale, original_dims):
        if not original_dims:
            return gr.update(), gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        new_w = math.floor((orig_w * scale / 100) / 16) * 16  # Ensure divisible by 16
        new_h = math.floor((orig_h * scale / 100) / 16) * 16  # Ensure divisible by 16
        return gr.update(value=new_w), gr.update(value=new_h)

    def update_v2v_dimensions(video):
        if video is None:
            return "", gr.update(value=544), gr.update(value=544)
        cap = cv2.VideoCapture(video)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        # Make dimensions divisible by 16
        w = (w // 16) * 16
        h = (h // 16) * 16
        return f"{w}x{h}", w, h

    # Event Handlers for Video to Video Tab
    v2v_input.change(
        fn=update_v2v_dimensions,
        inputs=[v2v_input],
        outputs=[v2v_original_dims, v2v_width, v2v_height]
    )
    v2v_scale_slider.change(
        fn=update_v2v_from_scale,
        inputs=[v2v_scale_slider, v2v_original_dims],
        outputs=[v2v_width, v2v_height]
    )
    v2v_calc_width_btn.click(
        fn=calculate_v2v_width,
        inputs=[v2v_height, v2v_original_dims],
        outputs=[v2v_width]
    )
    v2v_calc_height_btn.click(
        fn=calculate_v2v_height,
        inputs=[v2v_width, v2v_original_dims],
        outputs=[v2v_height]
    )

    ## Image 2 video dimension logic
    def calculate_width(height, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_width = math.floor((height * aspect_ratio) / 16) * 16  # Changed from 8 to 16
        return gr.update(value=new_width)

    def calculate_height(width, original_dims):
        if not original_dims:
            return gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        aspect_ratio = orig_w / orig_h
        new_height = math.floor((width / aspect_ratio) / 16) * 16  # Changed from 8 to 16
        return gr.update(value=new_height)

    def update_from_scale(scale, original_dims):
        if not original_dims:
            return gr.update(), gr.update()
        orig_w, orig_h = map(int, original_dims.split('x'))
        new_w = math.floor((orig_w * scale / 100) / 16) * 16  # Changed from 8 to 16
        new_h = math.floor((orig_h * scale / 100) / 16) * 16  # Changed from 8 to 16
        return gr.update(value=new_w), gr.update(value=new_h)

    def update_dimensions(image):
        if image is None:
            return "", gr.update(value=544), gr.update(value=544)
        img = Image.open(image)
        w, h = img.size
        # Make dimensions divisible by 16
        w = (w // 16) * 16  # Changed from 8 to 16
        h = (h // 16) * 16  # Changed from 8 to 16
        return f"{w}x{h}", w, h

    i2v_input.change(
        fn=update_dimensions,
        inputs=[i2v_input],
        outputs=[original_dims, width, height]
    )
    scale_slider.change(
        fn=update_from_scale,
        inputs=[scale_slider, original_dims],
        outputs=[width, height]
    )
    calc_width_btn.click(
        fn=calculate_width,
        inputs=[height, original_dims],
        outputs=[width]
    )
    calc_height_btn.click(
        fn=calculate_height,
        inputs=[width, original_dims],
        outputs=[height]
    )
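    # Example of the divisible-by-16 rounding used by the dimension helpers above:
    # a 1280x720 source at a 50% scale gives floor(640 / 16) * 16 = 640 and
    # floor(360 / 16) * 16 = 352, so the suggested generation size becomes 640x352.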
    # Function to get available DiT models
    def get_dit_models(dit_folder: str) -> List[str]:
        if not os.path.exists(dit_folder):
            return ["mp_rank_00_model_states.pt"]
        models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')]
        models.sort(key=str.lower)
        return models if models else ["mp_rank_00_model_states.pt"]

    # Function to perform model merging
    def merge_models(
        dit_folder: str, dit_model: str, output_model: str, exclude_single_blocks: bool,
        merge_lora_folder: str, *lora_params  # Will contain both weights and multipliers
    ) -> str:
        try:
            # Separate weights and multipliers
            num_loras = len(lora_params) // 2
            weights = list(lora_params[:num_loras])
            multipliers = list(lora_params[num_loras:])

            # Filter out "None" selections
            valid_loras = []
            for weight, mult in zip(weights, multipliers):
                if weight and weight != "None":
                    valid_loras.append((os.path.join(merge_lora_folder, weight), mult))
            if not valid_loras:
                return "No LoRA models selected for merging"

            # Create output path in the dit folder
            os.makedirs(dit_folder, exist_ok=True)
            output_path = os.path.join(dit_folder, output_model)

            # Prepare command
            cmd = [
                sys.executable, "merge_lora.py",
                "--dit", os.path.join(dit_folder, dit_model),
                "--save_merged_model", output_path
            ]
            # Add LoRA weights and multipliers
            weights = [weight for weight, _ in valid_loras]
            multipliers = [str(mult) for _, mult in valid_loras]
            cmd.extend(["--lora_weight"] + weights)
            cmd.extend(["--lora_multiplier"] + multipliers)
            if exclude_single_blocks:
                cmd.append("--exclude_single_blocks")

            # Execute merge operation
            result = subprocess.run(
                cmd, capture_output=True, text=True, check=True
            )
            if os.path.exists(output_path):
                return f"Successfully merged model and saved to {output_path}"
            else:
                return "Error: Output file not created"
        except subprocess.CalledProcessError as e:
            return f"Error during merging: {e.stderr}"
        except Exception as e:
            return f"Error: {str(e)}"

    # Update DiT model dropdown
    def update_dit_dropdown(dit_folder: str) -> Dict:
        models = get_dit_models(dit_folder)
        return gr.update(choices=models, value=models[0] if models else None)

    # Connect events
    merge_btn.click(
        fn=merge_models,
        inputs=[
            dit_folder, dit_model, output_model, exclude_single_blocks, merge_lora_folder,
            *merge_lora_weights, *merge_lora_multipliers
        ],
        outputs=merge_status
    )

    # Refresh buttons for both DiT and LoRA dropdowns
    merge_refresh_btn.click(
        fn=lambda f: update_dit_dropdown(f),
        inputs=[dit_folder],
        outputs=[dit_model]
    )

    # LoRA refresh handling
    merge_refresh_outputs = []
    for i in range(4):
        merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]])
    merge_refresh_btn.click(
        fn=update_lora_dropdowns,
        inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers,
        outputs=merge_refresh_outputs
    )

    # Event handlers
    prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter)
    v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter)
    stop_btn.click(fn=lambda: stop_event.set(), queue=False)
    v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)
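    # The helper below turns a still image into a temporary clip by looping it with ffmpeg.
    # As a rough illustration (with the default frames=240 at 24 fps, so -t 240/24 = 10.0),
    # the command it builds is equivalent to:
    #   ffmpeg -loop 1 -i temp_resized_image.png -c:v libx264 -t 10.0 -pix_fmt yuv420p \
    #          -vf fps=24 <output_path>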
    # Image_to_Video
    def image_to_video(image_path, output_path, width, height, frames=240):  # Add width, height parameters
        img = Image.open(image_path)
        # Resize to the specified dimensions
        img_resized = img.resize((width, height), Image.LANCZOS)
        temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png")
        img_resized.save(temp_image_path)

        # Rest of function remains the same
        frame_rate = 24
        duration = frames / frame_rate
        command = [
            "ffmpeg", "-loop", "1", "-i", temp_image_path,
            "-c:v", "libx264", "-t", str(duration),
            "-pix_fmt", "yuv420p", "-vf", f"fps={frame_rate}",
            output_path
        ]
        try:
            subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            print(f"Video saved to {output_path}")
            return True
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while creating the video: {e}")
            return False
        finally:
            # Clean up the temporary image file
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
            img.close()  # Make sure to close the image file explicitly

    def generate_from_image(
        image_path, prompt, width, height, video_length, fps, infer_steps, seed,
        model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode,
        block_swap, exclude_single_blocks, use_split_attn, lora_folder, strength,
        batch_size, *lora_params
    ):
        """Generate video from input image with progressive updates"""
        global stop_event
        stop_event.clear()

        # Create temporary video path
        temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4")
        try:
            # Convert image to video
            if not image_to_video(image_path, temp_video_path, width, height, frames=video_length):
                yield [], "Failed to create temporary video", "Error in video creation"
                return

            # Ensure video is fully written before proceeding
            time.sleep(1)
            if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0:
                yield [], "Failed to create temporary video", "Temporary video file is empty or missing"
                return

            # Get video dimensions
            try:
                probe = ffmpeg.probe(temp_video_path)
                video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
                if video_stream is None:
                    raise ValueError("No video stream found")
                width = int(video_stream['width'])
                height = int(video_stream['height'])
            except Exception as e:
                yield [], f"Error reading video dimensions: {str(e)}", "Video processing error"
                return

            # Generate the video using the temporary file
            try:
                generator = process_single_video(
                    prompt, width, height, batch_size, video_length, fps, infer_steps, seed,
                    model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type,
                    attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder,
                    *lora_params, video_path=temp_video_path, strength=strength
                )
                # Forward all generator updates
                for videos, batch_text, progress_text in generator:
                    yield videos, batch_text, progress_text
            except Exception as e:
                yield [], f"Error in video generation: {str(e)}", "Generation error"
                return
        except Exception as e:
            yield [], f"Unexpected error: {str(e)}", "Error occurred"
            return
        finally:
            # Clean up temporary file
            try:
                if os.path.exists(temp_video_path):
                    os.remove(temp_video_path)
            except Exception:
                pass  # Ignore cleanup errors

    # Add event handlers
    i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter)
    i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

    def handle_i2v_gallery_select(evt: gr.SelectData) -> int:
        """Track selected index when I2V gallery item is clicked"""
        return evt.index
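    # The gallery helpers below (and their counterparts in the other tabs) branch on dicts,
    # tuples/lists and plain strings because, depending on the Gradio version, a Gallery item
    # can arrive as {"name": path, ...}, as a (path, caption) pair, or as just the path string.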
    def send_i2v_to_v2v(
        gallery: list, prompt: str, selected_index: int,
        width: int, height: int, video_length: int, fps: int, infer_steps: int,
        seed: int, flow_shift: float, cfg_scale: float,
        lora1: str, lora2: str, lora3: str, lora4: str,
        lora1_multiplier: float, lora2_multiplier: float,
        lora3_multiplier: float, lora4_multiplier: float
    ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float,
               str, str, str, str, float, float, float, float]:
        """Send the selected video and parameters from Image2Video tab to Video2Video tab"""
        if not gallery or selected_index is None or selected_index >= len(gallery):
            return (None, "", width, height, video_length, fps, infer_steps, seed,
                    flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
                    lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier)
        selected_item = gallery[selected_index]
        # Handle different gallery item formats
        if isinstance(selected_item, dict):
            video_path = selected_item.get("name", selected_item.get("data", None))
        elif isinstance(selected_item, (tuple, list)):
            video_path = selected_item[0]
        else:
            video_path = selected_item
        # Final cleanup for Gradio Video component
        if isinstance(video_path, tuple):
            video_path = video_path[0]
        # Use the original width and height without doubling
        return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed,
                flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
                lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier)

    # Generate button handler
    i2v_generate_btn.click(
        fn=process_batch,
        inputs=[
            i2v_prompt, width, height, i2v_batch_size, i2v_video_length, i2v_fps,
            i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2,
            i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode,
            i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder,
            *i2v_lora_weights, *i2v_lora_multipliers,
            i2v_input, i2v_strength, i2v_use_fp8
        ],
        outputs=[i2v_output, i2v_batch_progress, i2v_progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[i2v_batch_size],
        outputs=i2v_selected_index
    )

    # Send to Video2Video
    i2v_output.select(
        fn=handle_i2v_gallery_select,
        outputs=i2v_selected_index
    )
    i2v_send_to_v2v_btn.click(
        fn=send_i2v_to_v2v,
        inputs=[
            i2v_output, i2v_prompt, i2v_selected_index, width, height,
            i2v_video_length, i2v_fps, i2v_infer_steps, i2v_seed, i2v_flow_shift, i2v_cfg_scale
        ] + i2v_lora_weights + i2v_lora_multipliers,
        outputs=[
            v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale
        ] + v2v_lora_weights + v2v_lora_multipliers
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    # Video Info
    def clean_video_path(video_path) -> str:
        """Extract clean video path from Gradio's various return formats"""
        print(f"Input video_path: {video_path}, type: {type(video_path)}")
        if isinstance(video_path, dict):
            path = video_path.get("name", "")
        elif isinstance(video_path, (tuple, list)):
            path = video_path[0]
        elif isinstance(video_path, str):
            path = video_path
        else:
            path = ""
        print(f"Cleaned path: {path}")
        return path

    def handle_video_upload(video_path: str) -> Tuple[Dict, str]:
        """Handle video upload and metadata extraction"""
        if not video_path:
            return {}, "No video uploaded"
        metadata = extract_video_metadata(video_path)
        if not metadata:
            return {}, "No metadata found in video"
        return metadata, "Metadata extracted successfully"

    def get_video_info(video_path: str) -> dict:
        try:
            probe = ffmpeg.probe(video_path)
            video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
            width = int(video_info['width'])
            height = int(video_info['height'])
            # Parse the frame-rate fraction instead of eval(); '30/1' -> 30.0, '30000/1001' -> 29.97
            num, den = video_info['r_frame_rate'].split('/')
            fps = float(num) / float(den)
            # Calculate total frames
            duration = float(probe['format']['duration'])
            total_frames = int(duration * fps)
            # Ensure video length does not exceed 201 frames
            if total_frames > 201:
                total_frames = 201
                duration = total_frames / fps  # Adjust duration accordingly
            return {
                'width': width,
                'height': height,
                'fps': fps,
                'total_frames': total_frames,
                'duration': duration  # Might be useful in some contexts
            }
        except Exception as e:
            print(f"Error extracting video info: {e}")
            return {}
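    # get_video_info caps clips at 201 frames: for example, a 12 s upload at 24 fps nominally
    # has 288 frames, so total_frames is clamped to 201 and the reported duration becomes
    # 201 / 24 = 8.375 s.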
    def extract_video_details(video_path: str) -> Tuple[dict, str]:
        metadata = extract_video_metadata(video_path)
        video_details = get_video_info(video_path)
        # Combine metadata with video details
        for key, value in video_details.items():
            if key not in metadata:
                metadata[key] = value
        # Ensure video length does not exceed 201 frames
        if 'video_length' in metadata:
            metadata['video_length'] = min(metadata['video_length'], 201)
        else:
            metadata['video_length'] = min(video_details.get('total_frames', 0), 201)
        # Return both the updated metadata and a status message
        return metadata, "Video details extracted successfully"

    def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]:
        """Create parameter mapping for target tab"""
        if not metadata:
            return "No parameters to send", {}
        tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video"
        try:
            mapping = create_parameter_transfer_map(metadata, target_tab)
            return f"Parameters ready for {tab_name}", mapping
        except Exception as e:
            return f"Error: {str(e)}", {}

    video_input.upload(
        fn=extract_video_details,
        inputs=video_input,
        outputs=[metadata_output, status]
    )

    send_to_t2v_btn.click(
        fn=lambda m: send_parameters_to_tab(m, "t2v"),
        inputs=metadata_output,
        outputs=[status, params_state]
    ).then(
        fn=change_to_tab_one,
        inputs=None,
        outputs=[tabs]
    ).then(
        lambda params: [
            params.get("prompt", ""),
            params.get("width", 544),
            params.get("height", 544),
            params.get("batch_size", 1),
            params.get("video_length", 25),
            params.get("fps", 24),
            params.get("infer_steps", 30),
            params.get("seed", -1),
            params.get("model", "hunyuan/mp_rank_00_model_states.pt"),
            params.get("vae", "hunyuan/pytorch_model.pt"),
            params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"),
            params.get("te2", "hunyuan/clip_l.safetensors"),
            params.get("save_path", "outputs"),
            params.get("flow_shift", 11.0),
            params.get("cfg_scale", 7.0),
            params.get("output_type", "video"),
            params.get("attn_mode", "sdpa"),
            params.get("block_swap", "0"),
            *[params.get(f"lora{i+1}", "") for i in range(4)],
            *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)]
        ] if params else [gr.update()] * 26,
        inputs=params_state,
        outputs=[
            prompt, width, height, batch_size, video_length, fps, infer_steps, seed,
            model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type,
            attn_mode, block_swap
        ] + lora_weights + lora_multipliers
    )

    # Text to Video generation
    generate_btn.click(
        fn=process_batch,
        inputs=[
            prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed,
            dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type,
            attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder,
            *lora_weights, *lora_multipliers,
            gr.Textbox(visible=False), gr.Number(visible=False), use_fp8
        ],
        outputs=[video_output, batch_progress, progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[batch_size],
        outputs=selected_index
    )

    # Update gallery selection handling
    def handle_gallery_select(evt: gr.SelectData) -> int:
        return evt.index

    # Track selected index when gallery item is clicked
    video_output.select(
        fn=handle_gallery_select,
        outputs=selected_index
    )

    # Track selected index when Video2Video gallery item is clicked
    def handle_v2v_gallery_select(evt: gr.SelectData) -> int:
        """Handle gallery selection without automatically updating the input"""
        return evt.index

    # Update the gallery selection event
    v2v_output.select(
        fn=handle_v2v_gallery_select,
        outputs=v2v_selected_index
    )
    # Send button handler with gallery selection
    def handle_send_button(
        gallery: list, prompt: str, idx: int,
        width: int, height: int, batch_size: int, video_length: int, fps: int,
        infer_steps: int, seed: int, flow_shift: float, cfg_scale: float,
        lora1: str, lora2: str, lora3: str, lora4: str,
        lora1_multiplier: float, lora2_multiplier: float,
        lora3_multiplier: float, lora4_multiplier: float
    ) -> tuple:
        # Auto-select the first item if only one exists and no selection was made
        # (this must happen before the guard below, otherwise it can never run)
        if gallery and idx is None and len(gallery) == 1:
            idx = 0
        if not gallery or idx is None or idx >= len(gallery):
            return (None, "", width, height, batch_size, video_length, fps, infer_steps, seed,
                    flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
                    lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
                    "")  # Add empty string for negative_prompt in the return values
        selected_item = gallery[idx]
        # Handle different gallery item formats
        if isinstance(selected_item, dict):
            video_path = selected_item.get("name", selected_item.get("data", None))
        elif isinstance(selected_item, (tuple, list)):
            video_path = selected_item[0]
        else:
            video_path = selected_item
        # Final cleanup for Gradio Video component
        if isinstance(video_path, tuple):
            video_path = video_path[0]
        return (
            str(video_path), prompt, width, height, batch_size, video_length, fps, infer_steps,
            seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
            lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
            ""  # Add empty string for negative_prompt
        )

    send_t2v_to_v2v_btn.click(
        fn=handle_send_button,
        inputs=[
            video_output, prompt, selected_index, t2v_width, t2v_height, batch_size,
            video_length, fps, infer_steps, seed, flow_shift, cfg_scale
        ] + lora_weights + lora_multipliers,
        outputs=[
            v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length,
            v2v_fps, v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale
        ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt]
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]:
        """Handle both parameters and video transfer"""
        status_msg, params = send_parameters_to_tab(metadata, "v2v")
        return status_msg, params, video_path

    def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]:
        """Handle both parameters and video transfer from Video Info to V2V tab"""
        if not video_path:
            return "No video selected", {}, None
        status_msg, params = send_parameters_to_tab(metadata, "v2v")
        # Just return the path directly
        return status_msg, params, video_path
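    # Both parameter-restore chains (send_to_t2v_btn above and send_to_v2v_btn below) unpack
    # 26 values: 18 scalar settings plus 4 LoRA weights plus 4 multipliers. The lambda's
    # [gr.update()] * 26 fallback and the length of the outputs list both rely on that count,
    # so keep them in sync if parameters are added or removed.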
    # Send button click handler
    send_to_v2v_btn.click(
        fn=handle_info_to_v2v,
        inputs=[metadata_output, video_input],
        outputs=[status, params_state, v2v_input]
    ).then(
        lambda params: [
            params.get("v2v_prompt", ""),
            params.get("v2v_width", 544),
            params.get("v2v_height", 544),
            params.get("v2v_batch_size", 1),
            params.get("v2v_video_length", 25),
            params.get("v2v_fps", 24),
            params.get("v2v_infer_steps", 30),
            params.get("v2v_seed", -1),
            params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"),
            params.get("v2v_vae", "hunyuan/pytorch_model.pt"),
            params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"),
            params.get("v2v_te2", "hunyuan/clip_l.safetensors"),
            params.get("v2v_save_path", "outputs"),
            params.get("v2v_flow_shift", 11.0),
            params.get("v2v_cfg_scale", 7.0),
            params.get("v2v_output_type", "video"),
            params.get("v2v_attn_mode", "sdpa"),
            params.get("v2v_block_swap", "0"),
            *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)],
            *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)]
        ] if params else [gr.update()] * 26,
        inputs=params_state,
        outputs=[
            v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, v2v_te2, v2v_save_path,
            v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, v2v_block_swap
        ] + v2v_lora_weights + v2v_lora_multipliers
    ).then(
        lambda: print(f"Tabs object: {tabs}"),  # Debug print
        outputs=None
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    # Handler for sending selected video from Video2Video gallery to input
    def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]:
        """Send the currently selected video in V2V gallery to V2V input"""
        if not gallery or idx is None or idx >= len(gallery):
            return None, ""
        selected_item = gallery[idx]
        video_path = None
        # Handle different gallery item formats
        if isinstance(selected_item, tuple):
            video_path = selected_item[0]  # Gallery returns (path, caption)
        elif isinstance(selected_item, dict):
            video_path = selected_item.get("name", selected_item.get("data", None))
        elif isinstance(selected_item, str):
            video_path = selected_item
        if not video_path:
            return None, ""
        # Check if the file exists and is accessible
        if not os.path.exists(video_path):
            print(f"Warning: Video file not found at {video_path}")
            return None, ""
        return video_path, prompt

    v2v_send_to_input_btn.click(
        fn=handle_v2v_send_button,
        inputs=[v2v_output, v2v_prompt, v2v_selected_index],
        outputs=[v2v_input, v2v_prompt]
    ).then(
        lambda: gr.update(visible=True),  # Ensure the video input is visible
        outputs=v2v_input
    )

    # Video to Video generation
    v2v_generate_btn.click(
        fn=process_batch,
        inputs=[
            v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2,
            v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode,
            v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder,
            *v2v_lora_weights, *v2v_lora_multipliers,
            v2v_input, v2v_strength, v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8
        ],
        outputs=[v2v_output, v2v_batch_progress, v2v_progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[v2v_batch_size],
        outputs=v2v_selected_index
    )

    refresh_outputs = [model]  # Add model dropdown to outputs
    for i in range(4):
        refresh_outputs.extend([lora_weights[i], lora_multipliers[i]])
    refresh_btn.click(
        fn=update_dit_and_lora_dropdowns,
        inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers,
        outputs=refresh_outputs
    )

    # Image2Video refresh
    i2v_refresh_outputs = [i2v_model]  # Add model dropdown to outputs
    for i in range(4):
        i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]])
    i2v_refresh_btn.click(
        fn=update_dit_and_lora_dropdowns,
        inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers,
        outputs=i2v_refresh_outputs
    )

    # Video2Video refresh
    v2v_refresh_outputs = [v2v_model]  # Add model dropdown to outputs
    for i in range(4):
        v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]])
    v2v_refresh_btn.click(
        fn=update_dit_and_lora_dropdowns,
        inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers,
        outputs=v2v_refresh_outputs
    )

    # WanX-i2v tab connections
    wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter)
    wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False)
    # Image input handling for WanX-i2v
    wanx_input.change(
        fn=update_wanx_image_dimensions,
        inputs=[wanx_input],
        outputs=[wanx_original_dims, wanx_width, wanx_height]
    )

    # Scale slider handling for WanX-i2v
    wanx_scale_slider.change(
        fn=update_wanx_from_scale,
        inputs=[wanx_scale_slider, wanx_original_dims],
        outputs=[wanx_width, wanx_height]
    )

    # Width/height calculation buttons for WanX-i2v
    wanx_calc_width_btn.click(
        fn=calculate_wanx_width,
        inputs=[wanx_height, wanx_original_dims],
        outputs=[wanx_width]
    )
    wanx_calc_height_btn.click(
        fn=calculate_wanx_height,
        inputs=[wanx_width, wanx_original_dims],
        outputs=[wanx_height]
    )

    # Flow shift recommendation buttons
    wanx_recommend_flow_btn.click(
        fn=recommend_wanx_flow_shift,
        inputs=[wanx_width, wanx_height],
        outputs=[wanx_flow_shift]
    )
    wanx_t2v_recommend_flow_btn.click(
        fn=recommend_wanx_flow_shift,
        inputs=[wanx_t2v_width, wanx_t2v_height],
        outputs=[wanx_t2v_flow_shift]
    )

    # Generate button handler
    wanx_generate_btn.click(
        fn=wanx_generate_video_batch,
        inputs=[
            wanx_prompt, wanx_negative_prompt, wanx_width, wanx_height, wanx_video_length,
            wanx_fps, wanx_infer_steps, wanx_flow_shift, wanx_guidance_scale, wanx_seed,
            wanx_task, wanx_dit_path, wanx_vae_path, wanx_t5_path, wanx_clip_path,
            wanx_save_path, wanx_output_type, wanx_sample_solver, wanx_attn_mode,
            wanx_block_swap, wanx_fp8, wanx_fp8_t5, wanx_batch_size,
            wanx_input,  # Image input
            wanx_lora_folder, *wanx_lora_weights, *wanx_lora_multipliers,
            wanx_exclude_single_blocks
        ],
        outputs=[wanx_output, wanx_batch_progress, wanx_progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[wanx_batch_size],
        outputs=skyreels_selected_index
    )

    # Gallery selection handling
    wanx_output.select(
        fn=handle_wanx_gallery_select,
        outputs=skyreels_selected_index  # Reuse the skyreels_selected_index
    )

    # Send to Video2Video handler
    wanx_send_to_v2v_btn.click(
        fn=send_wanx_to_v2v,
        inputs=[
            wanx_output, wanx_prompt,
            skyreels_selected_index,  # Reuse the skyreels_selected_index
            wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps,
            wanx_seed, wanx_flow_shift, wanx_guidance_scale, wanx_negative_prompt
        ],
        outputs=[
            v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale, v2v_negative_prompt
        ]
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    # Add state for T2V tab selected index
    wanx_t2v_selected_index = gr.State(value=None)

    # Connect prompt token counter
    wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter)

    # Stop button handler
    wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

    # Flow shift recommendation button
    wanx_t2v_recommend_flow_btn.click(
        fn=recommend_wanx_flow_shift,
        inputs=[wanx_t2v_width, wanx_t2v_height],
        outputs=[wanx_t2v_flow_shift]
    )

    # Task change handler to update CLIP visibility and path
    def update_clip_visibility(task):
        is_i2v = "i2v" in task
        return gr.update(visible=is_i2v)

    wanx_t2v_task.change(
        fn=update_clip_visibility,
        inputs=[wanx_t2v_task],
        outputs=[wanx_t2v_clip_path]
    )
    # Generate button handler for T2V
    wanx_t2v_generate_btn.click(
        fn=wanx_generate_video_batch,
        inputs=[
            wanx_t2v_prompt, wanx_t2v_negative_prompt, wanx_t2v_width, wanx_t2v_height,
            wanx_t2v_video_length, wanx_t2v_fps, wanx_t2v_infer_steps, wanx_t2v_flow_shift,
            wanx_t2v_guidance_scale, wanx_t2v_seed, wanx_t2v_task, wanx_t2v_dit_path,
            wanx_t2v_vae_path, wanx_t2v_t5_path, wanx_t2v_clip_path, wanx_t2v_save_path,
            wanx_t2v_output_type, wanx_t2v_sample_solver, wanx_t2v_attn_mode,
            wanx_t2v_block_swap, wanx_t2v_fp8, wanx_t2v_fp8_t5, wanx_t2v_batch_size,
            wanx_t2v_lora_folder, *wanx_t2v_lora_weights, *wanx_t2v_lora_multipliers,
            wanx_t2v_exclude_single_blocks
        ],
        outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text],
        queue=True
    ).then(
        fn=lambda batch_size: 0 if batch_size == 1 else None,
        inputs=[wanx_t2v_batch_size],
        outputs=wanx_t2v_selected_index
    )

    # Gallery selection handling
    wanx_t2v_output.select(
        fn=handle_wanx_t2v_gallery_select,
        outputs=wanx_t2v_selected_index
    )

    # Send to Video2Video handler
    wanx_t2v_send_to_v2v_btn.click(
        fn=send_wanx_t2v_to_v2v,
        inputs=[
            wanx_t2v_output, wanx_t2v_prompt, wanx_t2v_selected_index,
            wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps,
            wanx_t2v_infer_steps, wanx_t2v_seed, wanx_t2v_flow_shift,
            wanx_t2v_guidance_scale, wanx_t2v_negative_prompt
        ],
        outputs=[
            v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps,
            v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale, v2v_negative_prompt
        ]
    ).then(
        fn=change_to_tab_two,
        inputs=None,
        outputs=[tabs]
    )

    # Refresh handlers for WanX-i2v
    wanx_refresh_outputs = []
    for i in range(4):
        wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]])
    wanx_refresh_btn.click(
        fn=update_lora_dropdowns,
        inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers,
        outputs=wanx_refresh_outputs
    )

    # Refresh handlers for WanX-t2v
    wanx_t2v_refresh_outputs = []
    for i in range(4):
        wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]])
    wanx_t2v_refresh_btn.click(
        fn=update_lora_dropdowns,
        inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers,
        outputs=wanx_t2v_refresh_outputs
    )

demo.queue().launch(server_name="0.0.0.0", share=False)
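# Launch note: server_name="0.0.0.0" binds the interface to all network interfaces so other
# machines on the LAN can reach it, and share=False avoids creating a public gradio.live
# tunnel. queue() is needed because the generation handlers are generators that stream
# progress updates back to the UI.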