import gradio as gr from gradio import update as gr_update import subprocess import threading import time import re import os import random import tiktoken import sys import ffmpeg from typing import List, Tuple, Optional, Generator, Dict import json from gradio import themes from gradio.themes.utils import colors import subprocess from PIL import Image import math import cv2 import glob import shutil from pathlib import Path import logging from datetime import datetime from tqdm import tqdm # Add global stop event stop_event = threading.Event() logger = logging.getLogger(__name__) def process_hunyuani2v_video( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, lora1: str = "", lora2: str = "", lora3: str = "", lora4: str = "", lora1_multiplier: float = 1.0, lora2_multiplier: float = 1.0, lora3_multiplier: float = 1.0, lora4_multiplier: float = 1.0, video_path: Optional[str] = None, image_path: Optional[str] = None, strength: Optional[float] = None, negative_prompt: Optional[str] = None, embedded_cfg_scale: Optional[float] = None, split_uncond: Optional[bool] = None, guidance_scale: Optional[float] = None, use_fp8: bool = True, clip_vision_path: Optional[str] = None, i2v_stability: bool = False, fp8_fast: bool = False, compile_model: bool = False, compile_backend: str = "inductor", compile_mode: str = "max-autotune-no-cudagraphs", compile_dynamic: bool = False, compile_fullgraph: bool = False ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate a single video with the hunyuani2v script with updated parameters""" global stop_event if stop_event.is_set(): yield [], "", "" return # Determine if this is a SkyReels model and what type is_skyreels = "skyreels" in model.lower() is_skyreels_i2v = is_skyreels and "i2v" in model.lower() is_skyreels_t2v = is_skyreels and "t2v" in model.lower() # Set defaults for hunyuani2v specific parameters if is_skyreels: # Force certain parameters for SkyReels if negative_prompt is None: negative_prompt = "" if embedded_cfg_scale is None: embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels if split_uncond is None: split_uncond = True if guidance_scale is None: guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided else: embedded_cfg_scale = cfg_scale if os.path.isabs(model): model_path = model else: model_path = os.path.normpath(os.path.join(dit_folder, model)) env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" env["BATCH_RUN_ID"] = f"{time.time()}" if seed == -1: current_seed = random.randint(0, 2**32 - 1) else: batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) if batch_size > 1: # Only modify seed for batch generation current_seed = (seed + batch_id * 100003) % (2**32) else: current_seed = seed clear_cuda_cache() # Now use hv_generate_video_with_hunyuani2v.py instead command = [ sys.executable, "hv_generate_video_with_hunyuani2v.py", "--dit", model_path, "--vae", vae, "--text_encoder1", te1, "--text_encoder2", te2, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), 
"--flow_shift", str(flow_shift), "--embedded_cfg_scale", str(cfg_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--fp8_llm", "--vae_chunk_size", "32", "--vae_spatial_tile_sample_min_size", "128" ] if use_fp8: command.append("--fp8") # Add new parameters specific to hunyuani2v script if clip_vision_path: command.extend(["--clip_vision_path", clip_vision_path]) if i2v_stability: command.append("--i2v_stability") if fp8_fast: command.append("--fp8_fast") if compile_model: command.append("--compile") command.extend([ "--compile_args", compile_backend, compile_mode, str(compile_dynamic).lower(), str(compile_fullgraph).lower() ]) # Add negative prompt and embedded cfg scale command.extend(["--guidance_scale", str(guidance_scale)]) if negative_prompt: command.extend(["--negative_prompt", negative_prompt]) if split_uncond: command.append("--split_uncond") # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip([lora1, lora2, lora3, lora4], [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] command.extend(["--lora_weight"] + weights) command.extend(["--lora_multiplier"] + multipliers) if exclude_single_blocks: command.append("--exclude_single_blocks") if use_split_attn: command.append("--split_attn") # Handle input paths if video_path: command.extend(["--video_path", video_path]) if strength is not None: command.extend(["--strength", str(strength)]) elif image_path: command.extend(["--image_path", image_path]) # Only add strength parameter for non-SkyReels I2V models # SkyReels I2V doesn't use strength parameter for image-to-video generation if strength is not None and not is_skyreels_i2v: command.extend(["--strength", str(strength)]) print(f"{command}") p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) videos = [] while True: if stop_event.is_set(): p.terminate() p.wait() yield [], "", "Generation stopped by user." 
return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) # Collect parameters for metadata parameters = { "prompt": prompt, "width": width, "height": height, "video_length": video_length, "fps": fps, "infer_steps": infer_steps, "seed": current_seed, "model": model, "vae": vae, "te1": te1, "te2": te2, "save_path": save_path, "flow_shift": flow_shift, "cfg_scale": cfg_scale, "output_type": output_type, "attn_mode": attn_mode, "block_swap": block_swap, "lora_weights": [lora1, lora2, lora3, lora4], "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], "input_video": video_path if video_path else None, "input_image": image_path if image_path else None, "strength": strength, "negative_prompt": negative_prompt, "embedded_cfg_scale": embedded_cfg_scale, "clip_vision_path": clip_vision_path, "i2v_stability": i2v_stability, "fp8_fast": fp8_fast, "compile_model": compile_model } add_metadata_to_video(video_path, parameters) videos.append((str(video_path), f"Seed: {current_seed}")) yield videos, f"Completed (seed: {current_seed})", "" # Now let's create a new batch processing function that uses the hunyuani2v function def process_hunyuani2v_batch( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, *args ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Process a batch of videos using the hunyuani2v script""" global stop_event stop_event.clear() all_videos = [] progress_text = "Starting generation..." 
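    # Positional layout of *args as unpacked below: 4 LoRA weight names, 4 LoRA
    # multipliers, then input_path, strength, negative_prompt, guidance_scale,
    # split_uncond, use_fp8, clip_vision_path, i2v_stability, fp8_fast,
    # compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph.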
yield [], "Preparing...", progress_text # Extract additional arguments num_lora_weights = 4 lora_weights = args[:num_lora_weights] lora_multipliers = args[num_lora_weights:num_lora_weights*2] # New parameters for hunyuani2v # Base parameter list index after lora weights and multipliers base_idx = num_lora_weights*2 # Extract parameters input_path = args[base_idx] if len(args) > base_idx else None strength = float(args[base_idx+1]) if len(args) > base_idx+1 and args[base_idx+1] is not None else None negative_prompt = str(args[base_idx+2]) if len(args) > base_idx+2 and args[base_idx+2] is not None else None guidance_scale = float(args[base_idx+3]) if len(args) > base_idx+3 and args[base_idx+3] is not None else cfg_scale split_uncond = bool(args[base_idx+4]) if len(args) > base_idx+4 else None use_fp8 = bool(args[base_idx+5]) if len(args) > base_idx+5 else True # New hunyuani2v parameters clip_vision_path = str(args[base_idx+6]) if len(args) > base_idx+6 and args[base_idx+6] is not None else None i2v_stability = bool(args[base_idx+7]) if len(args) > base_idx+7 else False fp8_fast = bool(args[base_idx+8]) if len(args) > base_idx+8 else False compile_model = bool(args[base_idx+9]) if len(args) > base_idx+9 else False compile_backend = str(args[base_idx+10]) if len(args) > base_idx+10 and args[base_idx+10] is not None else "inductor" compile_mode = str(args[base_idx+11]) if len(args) > base_idx+11 and args[base_idx+11] is not None else "max-autotune-no-cudagraphs" compile_dynamic = bool(args[base_idx+12]) if len(args) > base_idx+12 else False compile_fullgraph = bool(args[base_idx+13]) if len(args) > base_idx+13 else False embedded_cfg_scale = cfg_scale for i in range(batch_size): if stop_event.is_set(): break batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Handle different input types video_path = None image_path = None if input_path: is_image = False lower_path = input_path.lower() image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') is_image = any(lower_path.endswith(ext) for ext in image_extensions) if is_image: image_path = input_path else: video_path = input_path # Prepare arguments for process_hunyuani2v_video current_seed = seed + i if seed != -1 and batch_size > 1 else seed if seed != -1 else -1 hunyuani2v_args = [ prompt, width, height, batch_size, video_length, fps, infer_steps, current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder ] hunyuani2v_args.extend(lora_weights) hunyuani2v_args.extend(lora_multipliers) hunyuani2v_args.extend([ video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8, clip_vision_path, i2v_stability, fp8_fast, compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph ]) for videos, status, progress in process_hunyuani2v_video(*hunyuani2v_args): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress yield all_videos, "Batch complete", "" def variance_of_laplacian(image): """ Compute the variance of the Laplacian of the image. Higher variance indicates a sharper image. """ return cv2.Laplacian(image, cv2.CV_64F).var() def extract_sharpest_frame(video_path, frames_to_check=30): """ Extract the sharpest frame from the last N frames of the video. 
Args: video_path (str): Path to the video file frames_to_check (int): Number of frames from the end to check Returns: tuple: (temp_image_path, frame_number, sharpness_score) """ print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") print(f"Input video path: {video_path}") if not video_path or not os.path.exists(video_path): print("❌ Error: Video file does not exist") return None, None, None try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("❌ Error: Failed to open video file") return None, None, None total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") if total_frames < 1: print("❌ Error: Video contains 0 frames") return None, None, None # Determine how many frames to check (the last N frames) if frames_to_check > total_frames: frames_to_check = total_frames start_frame = 0 else: start_frame = total_frames - frames_to_check print(f"Checking frames {start_frame} to {total_frames-1}") # Find the sharpest frame sharpest_frame = None max_sharpness = -1 sharpest_frame_number = -1 # Set starting position cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) # Process frames with a progress bar with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: frame_idx = start_frame while frame_idx < total_frames: ret, frame = cap.read() if not ret: break # Convert to grayscale and calculate sharpness gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) sharpness = variance_of_laplacian(gray) # Update if this is the sharpest frame so far if sharpness > max_sharpness: max_sharpness = sharpness sharpest_frame = frame.copy() sharpest_frame_number = frame_idx frame_idx += 1 pbar.update(1) cap.release() if sharpest_frame is None: print("❌ Error: Failed to find a sharp frame") return None, None, None # Prepare output path temp_dir = os.path.abspath("temp_frames") os.makedirs(temp_dir, exist_ok=True) temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") print(f"Saving frame to: {temp_path}") # Write and verify if not cv2.imwrite(temp_path, sharpest_frame): print("❌ Error: Failed to write frame to file") return None, None, None if not os.path.exists(temp_path): print("❌ Error: Output file not created") return None, None, None # Calculate frame time in seconds frame_time = sharpest_frame_number / fps print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") return temp_path, sharpest_frame_number, max_sharpness except Exception as e: print(f"❌ Unexpected error: {str(e)}") return None, None, None finally: if 'cap' in locals(): cap.release() def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): """ Trim video up to the specified frame and save as a new video. 
Args: video_path (str): Path to the video file frame_number (int): Frame number to trim to output_dir (str): Directory to save the trimmed video Returns: str: Path to the trimmed video file """ print(f"\n=== Trimming video to frame {frame_number} ===") if not video_path or not os.path.exists(video_path): print("❌ Error: Video file does not exist") return None try: # Get video information cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("❌ Error: Failed to open video file") return None fps = cap.get(cv2.CAP_PROP_FPS) cap.release() # Calculate time in seconds time_seconds = frame_number / fps # Create output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) # Generate output filename timestamp = f"{int(time_seconds)}s" base_name = Path(video_path).stem output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") # Use ffmpeg to trim the video ( ffmpeg .input(video_path) .output(output_file, to=time_seconds, c="copy") .global_args('-y') # Overwrite output files .run(quiet=True) ) if not os.path.exists(output_file): print("❌ Error: Failed to create trimmed video") return None print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") return output_file except Exception as e: print(f"❌ Error trimming video: {str(e)}") return None def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): """ Extract the sharpest frame from the last N frames of the selected video Args: gallery: Gradio gallery component with videos selected_idx: Index of the selected video frames_to_check: Number of frames from the end to check Returns: tuple: (image_path, video_path, frame_number, sharpness) """ if gallery is None or not gallery: return None, None, None, "No videos in gallery" if selected_idx is None and len(gallery) == 1: selected_idx = 0 if selected_idx is None or selected_idx >= len(gallery): return None, None, None, "No video selected" # Get the video path item = gallery[selected_idx] if isinstance(item, tuple): video_path = item[0] elif isinstance(item, dict): video_path = item.get('name') or item.get('data') else: video_path = str(item) # Extract the sharpest frame image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) if image_path is None: return None, None, None, "Failed to extract sharpest frame" return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): """ Trim the video to the specified frame and prepare for extension. 
    Args:
        video_path: Path to the video file
        frame_number: Frame number to trim to
        save_path: Directory to save the trimmed video

    Returns:
        tuple: (trimmed_video_path, status_message)
    """
    if not video_path or not os.path.exists(video_path):
        return None, "No video selected or video file does not exist"

    if frame_number is None:
        return None, "No frame number provided, please extract sharpest frame first"

    # Trim the video
    trimmed_video = trim_video_to_frame(video_path, frame_number, save_path)

    if trimmed_video is None:
        return None, "Failed to trim video"

    return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension"


def send_last_frame_handler(gallery, selected_idx):
    """Handle sending last frame to input with better error handling"""
    if gallery is None or not gallery:
        return None, None

    if selected_idx is None and len(gallery) == 1:
        selected_idx = 0

    if selected_idx is None or selected_idx >= len(gallery):
        return None, None

    # Get the frame and video path
    frame = handle_last_frame_transfer(gallery, selected_idx)

    video_path = None
    if selected_idx < len(gallery):
        item = gallery[selected_idx]
        video_path = parse_video_path(item)

    return frame, video_path


def extract_last_frame(video_path: str) -> Optional[str]:
    """Extract last frame from video and return temporary image path with error handling"""
    print(f"\n=== Starting frame extraction ===")
    print(f"Input video path: {video_path}")

    if not video_path or not os.path.exists(video_path):
        print("❌ Error: Video file does not exist")
        return None

    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("❌ Error: Failed to open video file")
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames detected: {total_frames}")

        if total_frames < 1:
            print("❌ Error: Video contains 0 frames")
            return None

        # Extract last frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
        success, frame = cap.read()

        if not success or frame is None:
            print("❌ Error: Failed to read last frame")
            return None

        # Prepare output path
        temp_dir = os.path.abspath("temp_frames")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png")
        print(f"Saving frame to: {temp_path}")

        # Write and verify
        if not cv2.imwrite(temp_path, frame):
            print("❌ Error: Failed to write frame to file")
            return None

        if not os.path.exists(temp_path):
            print("❌ Error: Output file not created")
            return None

        print("✅ Frame extraction successful")
        return temp_path

    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        return None
    finally:
        if 'cap' in locals():
            cap.release()


def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]:
    """Improved frame transfer with video input validation"""
    try:
        if gallery is None or not gallery:
            raise ValueError("No videos generated yet")

        if selected_idx is None:
            # Auto-select last generated video if batch_size=1
            if len(gallery) == 1:
                selected_idx = 0
            else:
                raise ValueError("Please select a video first")

        if selected_idx >= len(gallery):
            raise ValueError("Invalid selection index")

        item = gallery[selected_idx]

        # Video file existence check
        video_path = parse_video_path(item)
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file missing: {video_path}")

        return extract_last_frame(video_path)

    except Exception as e:
        print(f"Frame transfer failed: {str(e)}")
        return None


def parse_video_path(item) -> str:
    """Parse different gallery item formats"""
    if isinstance(item, tuple):
        return item[0]
    elif isinstance(item, dict):
        return item.get('name') or item.get('data')
    return str(item)


def get_random_image_from_folder(folder_path):
    """Get a random image from the specified folder"""
    if not os.path.isdir(folder_path):
        return None, f"Error: {folder_path} is not a valid directory"

    # Get all image files in the folder
    image_files = []
    for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'):
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
    for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'):
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))

    if not image_files:
        return None, f"Error: No image files found in {folder_path}"

    # Select a random image
    random_image = random.choice(image_files)
    return random_image, f"Selected: {os.path.basename(random_image)}"


def resize_image_keeping_aspect_ratio(image_path, max_width, max_height):
    """Resize image keeping aspect ratio and ensuring dimensions are divisible by 16"""
    try:
        img = Image.open(image_path)
        width, height = img.size

        # Calculate aspect ratio
        aspect_ratio = width / height

        # Calculate new dimensions while maintaining aspect ratio
        if width > height:
            new_width = min(max_width, width)
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = min(max_height, height)
            new_width = int(new_height * aspect_ratio)

        # Make dimensions divisible by 16
        new_width = math.floor(new_width / 16) * 16
        new_height = math.floor(new_height / 16) * 16

        # Ensure minimum size
        new_width = max(16, new_width)
        new_height = max(16, new_height)

        # Resize image
        resized_img = img.resize((new_width, new_height), Image.LANCZOS)

        # Save to temporary file
        temp_path = f"temp_resized_{os.path.basename(image_path)}"
        resized_img.save(temp_path)

        return temp_path, (new_width, new_height)
    except Exception as e:
        return None, f"Error: {str(e)}"


# Function to process a batch of images from a folder
def batch_handler(
    use_random, prompt, negative_prompt, width, height, video_length, fps, infer_steps,
    seed, flow_shift, guidance_scale, embedded_cfg_scale, batch_size, input_folder_path,
    dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, block_swap,
    exclude_single_blocks, use_split_attn, use_fp8, split_uncond, lora_folder,
    *lora_params
):
    """Handle both folder-based batch processing and regular batch processing"""
    global stop_event

    # Check if this is a SkyReels model that needs special handling
    is_skyreels = "skyreels" in model.lower()
    is_skyreels_i2v = is_skyreels and "i2v" in model.lower()

    if use_random:
        # Random image from folder mode
        stop_event.clear()
        all_videos = []
        progress_text = "Starting generation..."
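        # Folder mode: each batch item picks a random image from input_folder_path,
        # resizes it to fit within width x height while keeping the aspect ratio and
        # snapping both sides down to multiples of 16 (e.g. a 1920x1080 source with a
        # 960x544 limit comes back as 960x528), then generates one video from it.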
yield [], "Preparing...", progress_text for i in range(batch_size): if stop_event.is_set(): break batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Get random image from folder random_image, status = get_random_image_from_folder(input_folder_path) if random_image is None: yield all_videos, f"Error in batch {i+1}: {status}", "" continue # Resize image resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) if resized_image is None: yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" continue # If we have dimensions, update them local_width, local_height = width, height if isinstance(size_info, tuple): local_width, local_height = size_info progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" else: progress_text = f"Using image: {os.path.basename(random_image)}" yield all_videos.copy(), batch_text, progress_text # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i # Process the image # For SkyReels models, we need to create a command with dit_in_channels=32 if is_skyreels_i2v: env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model # Extract parameters from lora_params num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] cmd = [ sys.executable, "hv_generate_video.py", "--dit", model_path, "--vae", vae, "--text_encoder1", te1, "--text_encoder2", te2, "--prompt", prompt, "--video_size", str(local_height), str(local_width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--embedded_cfg_scale", str(embedded_cfg_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--fp8_llm", "--vae_chunk_size", "32", "--vae_spatial_tile_sample_min_size", "128", "--dit_in_channels", "32", # This is crucial for SkyReels i2v "--image_path", resized_image # Pass the image directly ] if use_fp8: cmd.append("--fp8") if split_uncond: cmd.append("--split_uncond") if use_split_attn: cmd.append("--split_attn") if exclude_single_blocks: cmd.append("--exclude_single_blocks") if negative_prompt: cmd.extend(["--negative_prompt", negative_prompt]) if guidance_scale is not None: cmd.extend(["--guidance_scale", str(guidance_scale)]) # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip(lora_weights, lora_multipliers): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] cmd.extend(["--lora_weight"] + weights) cmd.extend(["--lora_multiplier"] + multipliers) print(f"Running command: {' '.join(cmd)}") # Run the process p = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) while True: if stop_event.is_set(): p.terminate() p.wait() yield all_videos, "Generation stopped by user.", "" return line = p.stdout.readline() if not line: if p.poll() is not None: break 
continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() p.stdout.close() p.wait() # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos_files = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) all_videos.append((str(video_path), f"Seed: {current_seed}")) else: # For non-SkyReels models, use the regular process_single_video function num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] single_video_args = [ prompt, local_width, local_height, 1, video_length, fps, infer_steps, current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder ] single_video_args.extend(lora_weights) single_video_args.extend(lora_multipliers) single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) for videos, status, progress in process_single_video(*single_video_args): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress # Clean up temporary file try: if os.path.exists(resized_image): os.remove(resized_image) except: pass # Clear CUDA cache between generations clear_cuda_cache() time.sleep(0.5) yield all_videos, "Batch complete", "" else: # Regular image input - this is the part we need to fix # When a SkyReels I2V model is used, we need to use the direct command approach # with dit_in_channels=32 explicitly specified, just like in the folder processing branch if is_skyreels_i2v: stop_event.clear() all_videos = [] progress_text = "Starting generation..." 
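        # SkyReels I2V with a directly supplied image: mirror the folder branch and call
        # hv_generate_video.py ourselves so that --dit_in_channels 32 and --image_path
        # are set explicitly; roughly:
        #   python hv_generate_video.py --dit <model> ... --dit_in_channels 32 --image_path <image>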
yield [], "Preparing...", progress_text # Extract lora parameters num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] # Print extra_args for debugging print(f"Extra args: {extra_args}") image_path = None if len(extra_args) > 0 and extra_args[0] is not None: image_path = extra_args[0] print(f"Image path found in extra_args[0]: {image_path}") if not image_path: print("No image path found in extra_args[0]") print(f"Full lora_params: {lora_params}") yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" return for i in range(batch_size): if stop_event.is_set(): yield all_videos, "Generation stopped by user", "" return # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Set up environment env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model # Build the command with dit_in_channels=32 cmd = [ sys.executable, "hv_generate_video.py", "--dit", model_path, "--vae", vae, "--text_encoder1", te1, "--text_encoder2", te2, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--embedded_cfg_scale", str(embedded_cfg_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--fp8_llm", "--vae_chunk_size", "32", "--vae_spatial_tile_sample_min_size", "128", "--dit_in_channels", "32", # This is crucial for SkyReels i2v "--image_path", image_path ] if use_fp8: cmd.append("--fp8") if split_uncond: cmd.append("--split_uncond") if use_split_attn: cmd.append("--split_attn") if exclude_single_blocks: cmd.append("--exclude_single_blocks") if negative_prompt: cmd.extend(["--negative_prompt", negative_prompt]) if guidance_scale is not None: cmd.extend(["--guidance_scale", str(guidance_scale)]) # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip(lora_weights, lora_multipliers): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] cmd.extend(["--lora_weight"] + weights) cmd.extend(["--lora_multiplier"] + multipliers) print(f"Running command: {' '.join(cmd)}") # Run the process p = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) while True: if stop_event.is_set(): p.terminate() p.wait() yield all_videos, "Generation stopped by user.", "" return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): 
all_videos_files = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) all_videos.append((str(video_path), f"Seed: {current_seed}")) # Clear CUDA cache between generations clear_cuda_cache() time.sleep(0.5) yield all_videos, "Batch complete", "" else: # For regular non-SkyReels models, use the original process_batch function regular_args = [ prompt, width, height, batch_size, video_length, fps, infer_steps, seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder ] yield from process_batch(*(regular_args + list(lora_params))) def get_dit_models(dit_folder: str) -> List[str]: """Get list of available DiT models in the specified folder""" if not os.path.exists(dit_folder): return ["mp_rank_00_model_states.pt"] models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] models.sort(key=str.lower) return models if models else ["mp_rank_00_model_states.pt"] def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: """Update both DiT and LoRA dropdowns""" # Get model lists dit_models = get_dit_models(dit_folder) lora_choices = get_lora_options(lora_folder) # Current values processing dit_value = current_values[0] if dit_value not in dit_models: dit_value = dit_models[0] if dit_models else None weights = current_values[1:5] multipliers = current_values[5:9] results = [gr.update(choices=dit_models, value=dit_value)] # Add LoRA updates for i in range(4): weight = weights[i] if i < len(weights) else "None" multiplier = multipliers[i] if i < len(multipliers) else 1.0 if weight not in lora_choices: weight = "None" results.extend([ gr.update(choices=lora_choices, value=weight), gr.update(value=multiplier) ]) return results def extract_video_metadata(video_path: str) -> Dict: """Extract metadata from video file using ffprobe.""" cmd = [ 'ffprobe', '-v', 'quiet', '-print_format', 'json', '-show_format', video_path ] try: result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) metadata = json.loads(result.stdout.decode('utf-8')) if 'format' in metadata and 'tags' in metadata['format']: comment = metadata['format']['tags'].get('comment', '{}') return json.loads(comment) return {} except Exception as e: print(f"Metadata extraction failed: {str(e)}") return {} def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: """Map metadata parameters to Gradio components for different tabs""" mapping = { 'common': { 'prompt': ('prompt', 'v2v_prompt'), 'width': ('width', 'v2v_width'), 'height': ('height', 'v2v_height'), 'batch_size': ('batch_size', 'v2v_batch_size'), 'video_length': ('video_length', 'v2v_video_length'), 'fps': ('fps', 'v2v_fps'), 'infer_steps': ('infer_steps', 'v2v_infer_steps'), 'seed': ('seed', 'v2v_seed'), 'model': ('model', 'v2v_model'), 'vae': ('vae', 'v2v_vae'), 'te1': ('te1', 'v2v_te1'), 'te2': ('te2', 'v2v_te2'), 'save_path': ('save_path', 'v2v_save_path'), 'flow_shift': ('flow_shift', 'v2v_flow_shift'), 'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), 'output_type': ('output_type', 'v2v_output_type'), 'attn_mode': ('attn_mode', 'v2v_attn_mode'), 'block_swap': ('block_swap', 'v2v_block_swap'), 
'negative_prompt': ('i2v_negative_prompt', 'v2v_negative_prompt'), 'clip_vision_path': ('i2v_clip_vision_path', None), 'i2v_stability': ('i2v_stability', None), 'fp8_fast': ('i2v_fp8_fast', None) }, 'lora': { 'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], 'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] } } results = {} for param, value in metadata.items(): # Handle common parameters if param in mapping['common']: target = mapping['common'][param][0 if target_tab == 't2v' else 1] results[target] = value # Handle LoRA parameters if param == 'lora_weights': for i, weight in enumerate(value[:4]): target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] results[target] = weight if param == 'lora_multipliers': for i, mult in enumerate(value[:4]): target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] results[target] = float(mult) return results def add_metadata_to_video(video_path: str, parameters: dict) -> None: """Add generation parameters to video metadata using ffmpeg.""" import json import subprocess # Convert parameters to JSON string params_json = json.dumps(parameters, indent=2) # Temporary output path temp_path = video_path.replace(".mp4", "_temp.mp4") # FFmpeg command to add metadata without re-encoding cmd = [ 'ffmpeg', '-i', video_path, '-metadata', f'comment={params_json}', '-codec', 'copy', temp_path ] try: # Execute FFmpeg command subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Replace original file with the metadata-enhanced version os.replace(temp_path, video_path) except subprocess.CalledProcessError as e: print(f"Failed to add metadata: {e.stderr.decode()}") if os.path.exists(temp_path): os.remove(temp_path) except Exception as e: print(f"Error: {str(e)}") def count_prompt_tokens(prompt: str) -> int: enc = tiktoken.get_encoding("cl100k_base") tokens = enc.encode(prompt) return len(tokens) def get_lora_options(lora_folder: str = "lora") -> List[str]: if not os.path.exists(lora_folder): return ["None"] lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] lora_files.sort(key=str.lower) return ["None"] + lora_files def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: new_choices = get_lora_options(lora_folder) weights = current_values[:4] multipliers = current_values[4:8] results = [] for i in range(4): weight = weights[i] if i < len(weights) else "None" multiplier = multipliers[i] if i < len(multipliers) else 1.0 if weight not in new_choices: weight = "None" results.extend([ gr.update(choices=new_choices, value=weight), gr.update(value=multiplier) ]) return results def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: """Transfer selected video and prompt to Video2Video tab""" if not gallery or evt.index >= len(gallery): return None, "", selected_index.value selected_item = gallery[evt.index] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] # Update the selected index selected_index.value = evt.index return str(video_path), prompt, evt.index def 
send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: """Send the currently selected video to V2V tab""" if not gallery or selected_index.value is None or selected_index.value >= len(gallery): return None, "" selected_item = gallery[selected_index.value] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] return str(video_path), prompt def clear_cuda_cache(): """Clear CUDA cache if available""" import torch if torch.cuda.is_available(): torch.cuda.empty_cache() # Optional: synchronize to ensure cache is cleared torch.cuda.synchronize() def wanx_batch_handler( use_random, prompt, negative_prompt, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, batch_size, input_folder_path, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, *lora_params ): """Handle both folder-based batch processing and regular processing for WanX""" global stop_event if use_random: # Random image from folder mode stop_event.clear() all_videos = [] progress_text = "Starting generation..." yield [], "Preparing...", progress_text # Ensure batch_size is treated as an integer batch_size = int(batch_size) # Process each item in the batch separately for i in range(batch_size): if stop_event.is_set(): yield all_videos, "Generation stopped by user", "" return batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Get random image from folder random_image, status = get_random_image_from_folder(input_folder_path) if random_image is None: yield all_videos, f"Error in batch {i+1}: {status}", "" continue # Resize image resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) if resized_image is None: yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" continue # Use the dimensions returned from the resize function local_width, local_height = width, height # Default fallback if isinstance(size_info, tuple): local_width, local_height = size_info progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height} (maintaining aspect ratio)" else: progress_text = f"Using image: {os.path.basename(random_image)}" yield all_videos.copy(), batch_text, progress_text # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i # Extract LoRA weights and multipliers num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] # Generate video for this image - one at a time for videos, status, progress in wanx_generate_video( prompt, negative_prompt, resized_image, local_width, local_height, video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, *lora_weights, *lora_multipliers ): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: 
{status}", progress # Clean up temporary file try: if os.path.exists(resized_image): os.remove(resized_image) except: pass # Clear CUDA cache between generations clear_cuda_cache() time.sleep(0.5) yield all_videos, "Batch complete", "" else: # For non-random mode, if batch_size > 1, we need to process multiple times # with the same input image but different seeds if int(batch_size) > 1: stop_event.clear() all_videos = [] progress_text = "Starting generation..." yield [], "Preparing...", progress_text # Extract LoRA weights and multipliers and input image num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None # Process each batch item for i in range(int(batch_size)): if stop_event.is_set(): yield all_videos, "Generation stopped by user", "" return # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Generate a single video with the current seed for videos, status, progress in wanx_generate_video( prompt, negative_prompt, input_image, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, *lora_weights, *lora_multipliers ): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress # Clear CUDA cache between generations clear_cuda_cache() time.sleep(0.5) yield all_videos, "Batch complete", "" else: # Single image, single generation - use existing function num_lora_weights = 4 lora_weights = lora_params[:num_lora_weights] lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None yield from wanx_generate_video( prompt, negative_prompt, input_image, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, *lora_weights, *lora_multipliers ) def process_single_video( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, lora1: str = "", lora2: str = "", lora3: str = "", lora4: str = "", lora1_multiplier: float = 1.0, lora2_multiplier: float = 1.0, lora3_multiplier: float = 1.0, lora4_multiplier: float = 1.0, video_path: Optional[str] = None, image_path: Optional[str] = None, strength: Optional[float] = None, negative_prompt: Optional[str] = None, embedded_cfg_scale: Optional[float] = None, split_uncond: Optional[bool] = None, guidance_scale: Optional[float] = None, use_fp8: bool = True ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate a single video with the given parameters""" global stop_event if stop_event.is_set(): yield [], "", "" return # Determine if this is a 
SkyReels model and what type is_skyreels = "skyreels" in model.lower() is_skyreels_i2v = is_skyreels and "i2v" in model.lower() is_skyreels_t2v = is_skyreels and "t2v" in model.lower() if is_skyreels: # Force certain parameters for SkyReels if negative_prompt is None: negative_prompt = "" if embedded_cfg_scale is None: embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels if split_uncond is None: split_uncond = True if guidance_scale is None: guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided # Determine the input channels based on model type if is_skyreels_i2v: dit_in_channels = 32 # SkyReels I2V uses 32 channels else: dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) else: dit_in_channels = 16 # Regular Hunyuan models use 16 channels embedded_cfg_scale = cfg_scale if os.path.isabs(model): model_path = model else: model_path = os.path.normpath(os.path.join(dit_folder, model)) env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" env["BATCH_RUN_ID"] = f"{time.time()}" if seed == -1: current_seed = random.randint(0, 2**32 - 1) else: batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) if batch_size > 1: # Only modify seed for batch generation current_seed = (seed + batch_id * 100003) % (2**32) else: current_seed = seed clear_cuda_cache() command = [ sys.executable, "hv_generate_video.py", "--dit", model_path, "--vae", vae, "--text_encoder1", te1, "--text_encoder2", te2, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--embedded_cfg_scale", str(cfg_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--fp8_llm", "--vae_chunk_size", "32", "--vae_spatial_tile_sample_min_size", "128" ] if use_fp8: command.append("--fp8") # Add negative prompt and embedded cfg scale for SkyReels if is_skyreels: command.extend(["--dit_in_channels", str(dit_in_channels)]) command.extend(["--guidance_scale", str(guidance_scale)]) if negative_prompt: command.extend(["--negative_prompt", negative_prompt]) if split_uncond: command.append("--split_uncond") # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip([lora1, lora2, lora3, lora4], [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] command.extend(["--lora_weight"] + weights) command.extend(["--lora_multiplier"] + multipliers) if exclude_single_blocks: command.append("--exclude_single_blocks") if use_split_attn: command.append("--split_attn") # Handle input paths if video_path: command.extend(["--video_path", video_path]) if strength is not None: command.extend(["--strength", str(strength)]) elif image_path: command.extend(["--image_path", image_path]) # Only add strength parameter for non-SkyReels I2V models # SkyReels I2V doesn't use strength parameter for image-to-video generation if strength is not None and not is_skyreels_i2v: command.extend(["--strength", str(strength)]) print(f"{command}") p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', 
errors='replace', bufsize=1 ) videos = [] while True: if stop_event.is_set(): p.terminate() p.wait() yield [], "", "Generation stopped by user." return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) # Collect parameters for metadata parameters = { "prompt": prompt, "width": width, "height": height, "video_length": video_length, "fps": fps, "infer_steps": infer_steps, "seed": current_seed, "model": model, "vae": vae, "te1": te1, "te2": te2, "save_path": save_path, "flow_shift": flow_shift, "cfg_scale": cfg_scale, "output_type": output_type, "attn_mode": attn_mode, "block_swap": block_swap, "lora_weights": [lora1, lora2, lora3, lora4], "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], "input_video": video_path if video_path else None, "input_image": image_path if image_path else None, "strength": strength, "negative_prompt": negative_prompt if is_skyreels else None, "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None } add_metadata_to_video(video_path, parameters) videos.append((str(video_path), f"Seed: {current_seed}")) yield videos, f"Completed (seed: {current_seed})", "" # The issue is in the process_batch function, in the section that handles different input types # Here's the corrected version of that section: def process_batch( prompt: str, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, dit_folder: str, model: str, vae: str, te1: str, te2: str, save_path: str, flow_shift: float, cfg_scale: float, output_type: str, attn_mode: str, block_swap: int, exclude_single_blocks: bool, use_split_attn: bool, lora_folder: str, *args ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Process a batch of videos using Gradio's queue""" global stop_event stop_event.clear() all_videos = [] progress_text = "Starting generation..." 
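    # Positional layout of *args as unpacked below: 4 LoRA weight names, 4 LoRA
    # multipliers, then input_path, strength, negative_prompt, guidance_scale and
    # split_uncond; use_fp8 is read from the final element.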
yield [], "Preparing...", progress_text # Extract additional arguments num_lora_weights = 4 lora_weights = args[:num_lora_weights] lora_multipliers = args[num_lora_weights:num_lora_weights*2] extra_args = args[num_lora_weights*2:] # Determine if this is a SkyReels model and what type is_skyreels = "skyreels" in model.lower() is_skyreels_i2v = is_skyreels and "i2v" in model.lower() is_skyreels_t2v = is_skyreels and "t2v" in model.lower() # Handle input paths and additional parameters input_path = extra_args[0] if extra_args else None strength = float(extra_args[1]) if len(extra_args) > 1 else None # Get use_fp8 flag (it should be the last parameter) use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True # Get SkyReels specific parameters if applicable if is_skyreels: # Always set embedded_cfg_scale to 1.0 for SkyReels models embedded_cfg_scale = 1.0 negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" # Use cfg_scale for guidance_scale parameter guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale split_uncond = True if len(extra_args) > 4 and extra_args[4] else False else: negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None guidance_scale = cfg_scale embedded_cfg_scale = cfg_scale split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None for i in range(batch_size): if stop_event.is_set(): break batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Handle different input types video_path = None image_path = None if input_path: # Check if it's an image file (common image extensions) is_image = False lower_path = input_path.lower() image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') is_image = any(lower_path.endswith(ext) for ext in image_extensions) # Only use image_path for SkyReels I2V models and actual image files if is_skyreels_i2v and is_image: image_path = input_path else: video_path = input_path # Prepare arguments for process_single_video single_video_args = [ prompt, width, height, batch_size, video_length, fps, infer_steps, seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder ] single_video_args.extend(lora_weights) single_video_args.extend(lora_multipliers) single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) for videos, status, progress in process_single_video(*single_video_args): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress yield all_videos, "Batch complete", "" def update_wanx_image_dimensions(image): """Update dimensions from uploaded image""" if image is None: return "", gr.update(value=832), gr.update(value=480) img = Image.open(image) w, h = img.size w = (w // 32) * 32 h = (h // 32) * 32 return f"{w}x{h}", w, h def calculate_wanx_width(height, original_dims): """Calculate width based on height maintaining aspect ratio""" if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_width = math.floor((height * aspect_ratio) / 32) * 32 return gr.update(value=new_width) def calculate_wanx_height(width, original_dims): """Calculate height based on width maintaining aspect ratio""" if not original_dims: return 
gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_height = math.floor((width / aspect_ratio) / 32) * 32 return gr.update(value=new_height) def update_wanx_from_scale(scale, original_dims): """Update dimensions based on scale percentage""" if not original_dims: return gr.update(), gr.update() orig_w, orig_h = map(int, original_dims.split('x')) new_w = math.floor((orig_w * scale / 100) / 32) * 32 new_h = math.floor((orig_h * scale / 100) / 32) * 32 return gr.update(value=new_w), gr.update(value=new_h) def recommend_wanx_flow_shift(width, height): """Get recommended flow shift value based on dimensions""" recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 return gr.update(value=recommended_shift) def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: """Track selected index and video path when gallery item is clicked""" if gallery is None: return None, None if evt.index >= len(gallery): return None, None selected_item = gallery[evt.index] video_path = None # Extract the video path based on the item type if isinstance(selected_item, tuple): video_path = selected_item[0] elif isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) else: video_path = selected_item return evt.index, video_path def wanx_generate_video( prompt, negative_prompt, input_image, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, lora1="None", lora2="None", lora3="None", lora4="None", lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0 ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate video with WanX model (supports both i2v and t2v)""" global stop_event if stop_event.is_set(): yield [], "", "" return if seed == -1: current_seed = random.randint(0, 2**32 - 1) else: current_seed = seed # Check if we need input image (required for i2v, not for t2v) if "i2v" in task and not input_image: yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" return # Prepare environment env = os.environ.copy() env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") env["PYTHONIOENCODING"] = "utf-8" clear_cuda_cache() command = [ sys.executable, "wan_generate_video.py", "--task", task, "--prompt", prompt, "--video_size", str(height), str(width), "--video_length", str(video_length), "--fps", str(fps), "--infer_steps", str(infer_steps), "--save_path", save_path, "--seed", str(current_seed), "--flow_shift", str(flow_shift), "--guidance_scale", str(guidance_scale), "--output_type", output_type, "--attn_mode", attn_mode, "--blocks_to_swap", str(block_swap), "--dit", dit_path, "--vae", vae_path, "--t5", t5_path, "--sample_solver", sample_solver ] # Add image path only for i2v task and if input image is provided if "i2v" in task and input_image: command.extend(["--image_path", input_image]) command.extend(["--clip", clip_path]) # CLIP is only needed for i2v if negative_prompt: command.extend(["--negative_prompt", negative_prompt]) if fp8: command.append("--fp8") if fp8_t5: command.append("--fp8_t5") if exclude_single_blocks: command.append("--exclude_single_blocks") # Add LoRA weights and multipliers if provided valid_loras = [] for weight, mult in zip([lora1, 
lora2, lora3, lora4], [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): if weight and weight != "None": valid_loras.append((os.path.join(lora_folder, weight), mult)) if valid_loras: weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] command.extend(["--lora_weight"] + weights) command.extend(["--lora_multiplier"] + multipliers) print(f"Running: {' '.join(command)}") p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 ) videos = [] while True: if stop_event.is_set(): p.terminate() p.wait() yield [], "", "Generation stopped by user." return line = p.stdout.readline() if not line: if p.poll() is not None: break continue print(line, end='') if '|' in line and '%' in line and '[' in line and ']' in line: yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() p.stdout.close() p.wait() clear_cuda_cache() time.sleep(0.5) # Collect generated video save_path_abs = os.path.abspath(save_path) if os.path.exists(save_path_abs): all_videos = sorted( [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), reverse=True ) matching_videos = [v for v in all_videos if f"_{current_seed}" in v] if matching_videos: video_path = os.path.join(save_path_abs, matching_videos[0]) # Collect parameters for metadata parameters = { "prompt": prompt, "width": width, "height": height, "video_length": video_length, "fps": fps, "infer_steps": infer_steps, "seed": current_seed, "task": task, "flow_shift": flow_shift, "guidance_scale": guidance_scale, "output_type": output_type, "attn_mode": attn_mode, "block_swap": block_swap, "input_image": input_image if "i2v" in task else None } add_metadata_to_video(video_path, parameters) videos.append((str(video_path), f"Seed: {current_seed}")) yield videos, f"Completed (seed: {current_seed})", "" def send_wanx_to_v2v( gallery: list, prompt: str, selected_index: int, width: int, height: int, video_length: int, fps: int, infer_steps: int, seed: int, flow_shift: float, guidance_scale: float, negative_prompt: str ) -> Tuple: """Send the selected WanX video to Video2Video tab""" if gallery is None or not gallery: return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) # If no selection made but we have videos, use the first one if selected_index is None and len(gallery) > 0: selected_index = 0 if selected_index is None or selected_index >= len(gallery): return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) selected_item = gallery[selected_index] # Handle different gallery item formats if isinstance(selected_item, tuple): video_path = selected_item[0] elif isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) else: video_path = selected_item # Clean up path for Video component if isinstance(video_path, tuple): video_path = video_path[0] # Make sure it's a string video_path = str(video_path) return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) def wanx_generate_video_batch( prompt, negative_prompt, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, 
block_swap, fp8, fp8_t5, lora_folder, lora1="None", lora2="None", lora3="None", lora4="None", lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0, batch_size=1, input_image=None # Make input_image optional and place it at the end ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: """Generate videos with WanX with support for batches""" global stop_event stop_event.clear() all_videos = [] progress_text = "Starting generation..." yield [], "Preparing...", progress_text # Process each item in the batch for i in range(batch_size): if stop_event.is_set(): yield all_videos, "Generation stopped by user", "" return # Calculate seed for this batch item current_seed = seed if seed == -1: current_seed = random.randint(0, 2**32 - 1) elif batch_size > 1: current_seed = seed + i batch_text = f"Generating video {i + 1} of {batch_size}" yield all_videos.copy(), batch_text, progress_text # Generate a single video using the existing function for videos, status, progress in wanx_generate_video( prompt, negative_prompt, input_image, width, height, video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, task, dit_path, vae_path, t5_path, clip_path, save_path, output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, fp8, fp8_t5, lora_folder, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier ): if videos: all_videos.extend(videos) yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress yield all_videos, "Batch complete", "" def update_wanx_t2v_dimensions(size): """Update width and height based on selected size""" width, height = map(int, size.split('*')) return gr.update(value=width), gr.update(value=height) def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: """Track selected index when gallery item is clicked""" return evt.index def send_wanx_t2v_to_v2v( gallery, prompt, selected_index, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt ) -> Tuple: """Send the selected WanX T2V video to Video2Video tab""" if not gallery or selected_index is None or selected_index >= len(gallery): return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) selected_item = gallery[selected_index] if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item if isinstance(video_path, tuple): video_path = video_path[0] return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt) def prepare_for_batch_extension(input_img, base_video, batch_size): """Prepare inputs for batch video extension""" if input_img is None: return None, None, batch_size, "No input image found", "" if base_video is None: return input_img, None, batch_size, "No base video selected for extension", "" return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): """Concatenate multiple generated videos with the base video""" if not base_video_path: return [], "No base video provided" if not generated_videos or len(generated_videos) == 0: return [], "No new videos generated" # Create output directory if it doesn't 
exist os.makedirs(save_path, exist_ok=True) # Track all extended videos extended_videos = [] # For each generated video, create an extended version for i, video_item in enumerate(generated_videos): try: # Extract video path from gallery item if isinstance(video_item, tuple): new_video_path = video_item[0] seed_info = video_item[1] if len(video_item) > 1 else "" elif isinstance(video_item, dict): new_video_path = video_item.get("name", video_item.get("data", None)) seed_info = "" else: new_video_path = video_item seed_info = "" if not new_video_path or not os.path.exists(new_video_path): print(f"Skipping missing video: {new_video_path}") continue # Create unique output filename timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") # Extract seed from seed_info if available seed_match = re.search(r"Seed: (\d+)", seed_info) seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" output_path = os.path.join(save_path, output_filename) # Create a temporary file list for ffmpeg list_file = os.path.join(save_path, f"temp_list_{i}.txt") with open(list_file, "w") as f: f.write(f"file '{os.path.abspath(base_video_path)}'\n") f.write(f"file '{os.path.abspath(new_video_path)}'\n") # Run ffmpeg concatenation command = [ "ffmpeg", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", output_path ] subprocess.run(command, check=True, capture_output=True) # Clean up temporary file if os.path.exists(list_file): os.remove(list_file) # Add to extended videos list if successful if os.path.exists(output_path): seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" extended_videos.append((output_path, seed_display)) except Exception as e: print(f"Error processing video {i}: {str(e)}") if not extended_videos: return [], "Failed to create any extended videos" return extended_videos, f"Successfully created {len(extended_videos)} extended videos" def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: """Combine generated video with base video and update gallery""" if not base_video_path: return current_gallery, "Extend failed: No base video provided" if not new_videos: return current_gallery, "Extend failed: No new video generated" # Ensure save path exists os.makedirs(save_path, exist_ok=True) # Get the first video from new_videos (gallery item) new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] # Create a unique output filename timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" output_path = str(Path(save_path) / output_filename) try: # Concatenate the videos using ffmpeg ( ffmpeg .input(base_video_path) .concat( ffmpeg.input(new_video_path) ) .output(output_path) .run(overwrite_output=True, quiet=True) ) # Create a new gallery entry with the combined video updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] return updated_gallery, f"Successfully extended video to {Path(output_path).name}" except Exception as e: print(f"Error extending video: {str(e)}") return current_gallery, f"Failed to extend video: {str(e)}" # UI setup with gr.Blocks( theme=themes.Default( primary_hue=colors.Color( name="custom", c50="#E6F0FF", c100="#CCE0FF", c200="#99C1FF", c300="#66A3FF", c400="#3384FF", c500="#0060df", # This is your main color c600="#0052C2", 
c700="#003D91", c800="#002961", c900="#001430", c950="#000A18" ) ), css=""" .gallery-item:first-child { border: 2px solid #4CAF50 !important; } .gallery-item:first-child:hover { border-color: #45a049 !important; } .green-btn { background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; color: white !important; border: none !important; } .green-btn:hover { background: linear-gradient(to bottom right, #27ae60, #219651) !important; } .refresh-btn { max-width: 40px !important; min-width: 40px !important; height: 40px !important; border-radius: 50% !important; padding: 0 !important; display: flex !important; align-items: center !important; justify-content: center !important; } """, ) as demo: # Add state for tracking selected video indices in both tabs selected_index = gr.State(value=None) # For Text to Video v2v_selected_index = gr.State(value=None) # For Video to Video params_state = gr.State() #New addition i2v_selected_index = gr.State(value=None) skyreels_selected_index = gr.State(value=None) wanx_i2v_selected_index = gr.State(value=None) extended_videos = gr.State(value=[]) wanx_base_video = gr.State(value=None) wanx_sharpest_frame_number = gr.State(value=None) wanx_sharpest_frame_path = gr.State(value=None) wanx_trimmed_video_path = gr.State(value=None) demo.load(None, None, None, js=""" () => { document.title = 'H1111'; function updateTitle(text) { if (text && text.trim()) { const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); if (progressMatch) { const percentage = progressMatch[1]; const timeRemaining = progressMatch[2]; document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; } } } setTimeout(() => { const progressElements = document.querySelectorAll('textarea.scroll-hide'); progressElements.forEach(element => { if (element) { new MutationObserver(() => { updateTitle(element.value); }).observe(element, { attributes: true, childList: true, characterData: true }); } }); }, 1000); } """) with gr.Tabs() as tabs: # Text to Video Tab with gr.Tab(id=1, label="Hunyuan-t2v"): with gr.Row(): with gr.Column(scale=4): prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) with gr.Column(scale=1): token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): generate_btn = gr.Button("Generate Video", elem_classes="green-btn") stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") with gr.Column(): with gr.Row(): video_output = 
gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") with gr.Row(): refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") lora_weights = [] lora_multipliers = [] for i in range(4): with gr.Column(): lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) seed = gr.Number(label="Seed (use -1 for random)", value=-1) dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): lora_folder = gr.Textbox(label="LoRA Folder", value="lora") output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) #Image to Video Tab with gr.Tab(label="Hunyuan-i2v") as i2v_tab: with gr.Row(): with gr.Column(scale=4): i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) i2v_negative_prompt = gr.Textbox(label="Negative Prompt", value="", lines=2, info="Negative prompt") with gr.Column(scale=1): i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") i2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): i2v_input = gr.Image(label="Input Image", type="filepath") i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") # Scale slider as percentage scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height inputs with gr.Row(): width = gr.Number(label="New Width", value=544, step=16) calc_height_btn = gr.Button("→") calc_width_btn = gr.Button("←") height = gr.Number(label="New Height", value=544, step=16) i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) i2v_infer_steps = gr.Slider(minimum=10, maximum=100, 
step=1, label="Inference Steps", value=30) i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) with gr.Column(): i2v_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") # Add LoRA section for Image2Video i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") i2v_lora_weights = [] i2v_lora_multipliers = [] for i in range(4): with gr.Column(): i2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) i2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") i2v_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava-llama-3-8b-text-encoder-tokenizer") i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") i2v_clip_vision_path = gr.Textbox(label="CLIP Vision Path", value="hunyuan/clip-vit-large-patch14", info="Path to CLIP vision model for HunyuanI2V") i2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) i2v_split_uncond = gr.Checkbox(label="Split Unconditional", value=True, visible=True) with gr.Row(): i2v_stability = gr.Checkbox(label="I2V Stability", value=False, info="Enable stability mode for HunyuanI2V") i2v_fp8_fast = gr.Checkbox(label="FP8 Fast", value=False, info="Enable fast FP8 arithmetic (RTX 4XXX+)") i2v_compile = gr.Checkbox(label="Compile Model", value=False, info="Enable torch.compile for potentially faster generation") i2v_compile_backend = gr.Dropdown(label="Compile Backend", choices=["inductor", "cudagraphs", "onnxrt", "nvfuser"], value="inductor", info="Torch compile backend") i2v_compile_mode = gr.Dropdown(label="Compile Mode", choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], value="max-autotune-no-cudagraphs", info="Torch compile mode") i2v_compile_dynamic = gr.Checkbox(label="Dynamic Shapes", value=False, info="Use dynamic shapes in compilation") i2v_compile_fullgraph = gr.Checkbox(label="Full Graph", value=False, info="Use full graph compilation") # Video to Video Tab with gr.Tab(id=2, label="Hunyuan-v2v") as v2v_tab: with gr.Row(): with gr.Column(scale=4): v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", 
value="POV video of a cat chasing a frob.", lines=5) v2v_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt (for SkyReels models)", value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", lines=3 ) with gr.Column(scale=1): v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") v2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): v2v_input = gr.Video(label="Input Video", format="mp4") v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and Height Inputs with gr.Row(): v2v_width = gr.Number(label="New Width", value=544, step=16) v2v_calc_height_btn = gr.Button("→") v2v_calc_width_btn = gr.Button("←") v2v_height = gr.Number(label="New Height", value=544, step=16) v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) with gr.Column(): v2v_output = gr.Gallery( label="Generated Videos", columns=[1], rows=[1], object_fit="contain", height="auto" ) v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") v2v_lora_weights = [] v2v_lora_multipliers = [] for i in range(4): with gr.Column(): v2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) v2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") v2v_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("hunyuan"), value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") v2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) v2v_attn_mode = 
gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) ### SKYREELS with gr.Tab(label="SkyReels-i2v") as skyreels_tab: with gr.Row(): with gr.Column(scale=4): skyreels_prompt = gr.Textbox( scale=3, label="Enter your prompt", value="A person walking on a beach at sunset", lines=5 ) skyreels_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", lines=3 ) with gr.Column(scale=1): skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") with gr.Row(): skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) skyreels_input_folder = gr.Textbox( label="Image Folder Path", placeholder="Path to folder containing images", visible=False ) skyreels_folder_status = gr.Textbox( label="Folder Status", placeholder="Status will appear here", interactive=False, visible=False ) skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") # Scale slider as percentage skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height inputs with gr.Row(): skyreels_width = gr.Number(label="New Width", value=544, step=16) skyreels_calc_height_btn = gr.Button("→") skyreels_calc_width_btn = gr.Button("←") skyreels_height = gr.Number(label="New Height", value=544, step=16) skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) with gr.Column(): skyreels_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") # Add LoRA section for SKYREELS skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") skyreels_lora_weights = [] skyreels_lora_multipliers = [] for i in range(4): with gr.Column(): 
skyreels_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) skyreels_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") skyreels_model = gr.Dropdown( label="DiT Model", choices=get_dit_models("skyreels"), value="skyreels_hunyuan_i2v_bf16.safetensors", allow_custom_value=True, interactive=True ) skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) # WanX Image to Video Tab with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: with gr.Row(): with gr.Column(scale=4): wanx_prompt = gr.Textbox( scale=3, label="Enter your prompt", value="A person walking on a beach at sunset", lines=5 ) wanx_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="", lines=3, info="Leave empty to use default negative prompt" ) with gr.Column(scale=1): wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") wanx_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): wanx_input = gr.Image(label="Input Image", type="filepath") with gr.Row(): wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) wanx_input_folder = gr.Textbox( label="Image Folder Path", placeholder="Path to folder containing images", visible=False ) wanx_folder_status = gr.Textbox( label="Folder Status", placeholder="Status will appear here", interactive=False, visible=False ) wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) # Width and height display with gr.Row(): wanx_width = gr.Number(label="Width", value=832, interactive=True) wanx_calc_height_btn = gr.Button("→") wanx_calc_width_btn = gr.Button("←") wanx_height = gr.Number(label="Height", value=480, interactive=True) 
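# --- Illustrative sketch (not wired into the UI) ---
# The WanX width/height boxes above are driven by the helpers defined earlier
# (update_wanx_image_dimensions, calculate_wanx_width/height, update_wanx_from_scale,
# recommend_wanx_flow_shift). Below is a minimal standalone version of that dimension
# math, assuming the same 32-pixel alignment; the function names here are illustrative
# only and are not part of the app.
def _wanx_fit_dimensions_sketch(orig_w: int, orig_h: int, scale_pct: float = 100.0) -> tuple:
    """Scale a source resolution and snap both sides down to multiples of 32."""
    new_w = math.floor((orig_w * scale_pct / 100) / 32) * 32
    new_h = math.floor((orig_h * scale_pct / 100) / 32) * 32
    return new_w, new_h

def _wanx_recommended_flow_shift_sketch(width: int, height: int) -> float:
    """3.0 for 832x480 in either orientation, 5.0 otherwise (mirrors recommend_wanx_flow_shift)."""
    return 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0

# Example: a 1920x1080 image at 50% scale snaps to 960x512, which is not 832x480,
# so the recommended flow shift is 5.0.
# _wanx_fit_dimensions_sketch(1920, 1080, 50)       -> (960, 512)
# _wanx_recommended_flow_shift_sketch(960, 512)     -> 5.0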
wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, info="Recommended: 3.0 for 480p, 5.0 for others") wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) with gr.Column(): wanx_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") wanx_extend_btn = gr.Button("Extend Video") wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Frames to Check from End", info="Number of frames from the end to check for sharpness") wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) # Add a new button for directly extending with the trimmed video wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") # Add LoRA section for WanX-i2v similar to other tabs wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") wanx_lora_weights = [] wanx_lora_multipliers = [] for i in range(4): with gr.Column(): wanx_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) wanx_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) wanx_task = gr.Dropdown( label="Task", choices=["i2v-14B"], value="i2v-14B", info="Currently only i2v-14B is supported" ) wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") wanx_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) #WanX-t2v Tab # WanX Text to Video Tab with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: with gr.Row(): with gr.Column(scale=4): wanx_t2v_prompt = gr.Textbox( scale=3, label="Enter your prompt", 
value="A person walking on a beach at sunset", lines=5 ) wanx_t2v_negative_prompt = gr.Textbox( scale=3, label="Negative Prompt", value="", lines=3, info="Leave empty to use default negative prompt" ) with gr.Column(scale=1): wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) with gr.Column(scale=2): wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") with gr.Row(): wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") with gr.Row(): with gr.Column(): with gr.Row(): wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, info="Recommended: 3.0 for I2V with 480p, 5.0 for others") wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) with gr.Column(): wanx_t2v_output = gr.Gallery( label="Generated Videos (Click to select)", columns=[2], rows=[2], object_fit="contain", height="auto", show_label=True, elem_id="gallery", allow_preview=True, preview=True ) wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") # Add LoRA section for WanX-t2v wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") wanx_t2v_lora_weights = [] wanx_t2v_lora_multipliers = [] for i in range(4): with gr.Column(): wanx_t2v_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) wanx_t2v_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) wanx_t2v_task = gr.Dropdown( label="Task", choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], value="t2v-14B", info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" ) wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") with gr.Row(): wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", 
"xformers", "torch"], label="Attention Mode", value="sdpa") wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, info="Max 39 for 14B model, 29 for 1.3B model") wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) #Video Info Tab with gr.Tab("Video Info") as video_info_tab: with gr.Row(): video_input = gr.Video(label="Upload Video", interactive=True) metadata_output = gr.JSON(label="Generation Parameters") with gr.Row(): send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") with gr.Row(): status = gr.Textbox(label="Status", interactive=False) #Merge Model's tab with gr.Tab("Convert LoRA") as convert_lora_tab: def suggest_output_name(file_obj) -> str: """Generate suggested output name from input file""" if not file_obj: return "" # Get input filename without extension and add MUSUBI base_name = os.path.splitext(os.path.basename(file_obj.name))[0] return f"{base_name}_MUSUBI" def convert_lora(input_file, output_name: str, target_format: str) -> str: """Convert LoRA file to specified format""" try: if not input_file: return "Error: No input file selected" # Ensure output directory exists os.makedirs("lora", exist_ok=True) # Construct output path output_path = os.path.join("lora", f"{output_name}.safetensors") # Build command cmd = [ sys.executable, "convert_lora.py", "--input", input_file.name, "--output", output_path, "--target", target_format ] print(f"Converting {input_file.name} to {output_path}") # Execute conversion result = subprocess.run( cmd, capture_output=True, text=True, check=True ) if os.path.exists(output_path): return f"Successfully converted LoRA to {output_path}" else: return "Error: Output file not created" except subprocess.CalledProcessError as e: return f"Error during conversion: {e.stderr}" except Exception as e: return f"Error: {str(e)}" with gr.Row(): input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") format_radio = gr.Radio( choices=["default", "other"], value="default", label="Target Format", info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" ) with gr.Row(): convert_btn = gr.Button("Convert LoRA", variant="primary") status_output = gr.Textbox(label="Status", interactive=False) # Automatically update output name when file is selected input_file.change( fn=suggest_output_name, inputs=[input_file], outputs=[output_name] ) # Handle conversion convert_btn.click( fn=convert_lora, inputs=[input_file, output_name, format_radio], outputs=status_output ) with gr.Tab("Model Merging") as model_merge_tab: with gr.Row(): with gr.Column(): # Model selection dit_model = gr.Dropdown( label="Base DiT Model", choices=["mp_rank_00_model_states.pt"], value="mp_rank_00_model_states.pt", allow_custom_value=True, interactive=True ) merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") with gr.Row(): with gr.Column(): # Output model name output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) merge_btn = gr.Button("Merge Models", variant="primary") merge_status = 
gr.Textbox(label="Status", interactive=False) with gr.Row(): # LoRA selection section (similar to Text2Video) merge_lora_weights = [] merge_lora_multipliers = [] for i in range(4): with gr.Column(): merge_lora_weights.append(gr.Dropdown( label=f"LoRA {i+1}", choices=get_lora_options(), value="None", allow_custom_value=True, interactive=True )) merge_lora_multipliers.append(gr.Slider( label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0 )) with gr.Row(): merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") #Video Extension wanx_send_last_frame_btn.click( fn=send_last_frame_handler, inputs=[wanx_output, wanx_i2v_selected_index], outputs=[wanx_input, wanx_base_video] ) wanx_extend_btn.click( fn=prepare_for_batch_extension, inputs=[wanx_input, wanx_base_video, wanx_batch_size], outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] ).then( fn=wanx_batch_handler, inputs=[ gr.Checkbox(value=False), # Not using random folder wanx_prompt, wanx_negative_prompt, wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps, wanx_flow_shift, wanx_guidance_scale, wanx_seed, wanx_batch_size, wanx_input_folder, # Not used but needed for function signature wanx_task, wanx_dit_path, wanx_vae_path, wanx_t5_path, wanx_clip_path, wanx_save_path, wanx_output_type, wanx_sample_solver, wanx_exclude_single_blocks, wanx_attn_mode, wanx_block_swap, wanx_fp8, wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, *wanx_lora_multipliers, wanx_input # Include input image ], outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] ).then( fn=concat_batch_videos, inputs=[wanx_base_video, wanx_output, wanx_save_path], outputs=[wanx_output, wanx_progress_text] ) # Extract and send sharpest frame to input wanx_send_sharpest_frame_btn.click( fn=send_sharpest_frame_handler, inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] ) # Trim video to sharpest frame and prepare for extension wanx_trim_and_extend_btn.click( fn=trim_and_prepare_for_extension, inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] ).then( fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], outputs=[wanx_base_video, wanx_sharpest_frame_status] ) # Event handler for extending with the trimmed video wanx_extend_with_trimmed_btn.click( fn=prepare_for_batch_extension, inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] ).then( fn=wanx_batch_handler, inputs=[ gr.Checkbox(value=False), # Not using random folder wanx_prompt, wanx_negative_prompt, wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps, wanx_flow_shift, wanx_guidance_scale, wanx_seed, wanx_batch_size, wanx_input_folder, # Not used but needed for function signature wanx_task, wanx_dit_path, wanx_vae_path, wanx_t5_path, wanx_clip_path, wanx_save_path, wanx_output_type, wanx_sample_solver, wanx_exclude_single_blocks, wanx_attn_mode, wanx_block_swap, wanx_fp8, wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, *wanx_lora_multipliers, wanx_input # Include input image ], outputs=[wanx_output, wanx_batch_progress, 
wanx_progress_text] ).then( fn=concat_batch_videos, inputs=[wanx_trimmed_video_path, wanx_output, wanx_save_path], outputs=[wanx_output, wanx_progress_text] ) #Video Info def handle_send_to_wanx_tab(metadata, target_tab): """Common handler for sending video parameters to WanX tabs""" if not metadata: return "No parameters to send", {} # Tab names for clearer messages tab_names = { 'wanx_i2v': 'WanX-i2v', 'wanx_t2v': 'WanX-t2v' } # Just pass through all parameters - we'll use them in the .then() function return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata def change_to_wanx_i2v_tab(): return gr.Tabs(selected=4) # WanX-i2v tab index def change_to_wanx_t2v_tab(): return gr.Tabs(selected=5) # WanX-t2v tab index send_to_wanx_i2v_btn.click( fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_i2v'), inputs=[metadata_output], outputs=[status, params_state] ).then( lambda params: [ params.get("prompt", ""), params.get("width", 832), params.get("height", 480), params.get("video_length", 81), params.get("fps", 16), params.get("infer_steps", 40), params.get("seed", -1), params.get("flow_shift", 3.0), params.get("guidance_scale", 5.0), params.get("attn_mode", "sdpa"), params.get("block_swap", 0), params.get("task", "i2v-14B") ] if params else [gr.update()]*12, inputs=params_state, outputs=[ wanx_prompt, wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps, wanx_seed, wanx_flow_shift, wanx_guidance_scale, wanx_attn_mode, wanx_block_swap, wanx_task ] ).then( fn=change_to_wanx_i2v_tab, inputs=None, outputs=[tabs] ) # 3. Update the WanX-t2v button handler send_to_wanx_t2v_btn.click( fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), inputs=[metadata_output], outputs=[status, params_state] ).then( lambda params: [ params.get("prompt", ""), params.get("width", 832), params.get("height", 480), params.get("video_length", 81), params.get("fps", 16), params.get("infer_steps", 50), params.get("seed", -1), params.get("flow_shift", 5.0), params.get("guidance_scale", 5.0), params.get("attn_mode", "sdpa"), params.get("block_swap", 0) ] if params else [gr.update()]*11, inputs=params_state, outputs=[ wanx_t2v_prompt, wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps, wanx_t2v_infer_steps, wanx_t2v_seed, wanx_t2v_flow_shift, wanx_t2v_guidance_scale, wanx_t2v_attn_mode, wanx_t2v_block_swap ] ).then( fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] ) #text to video def change_to_tab_one(): return gr.Tabs(selected=1) #This will navigate #video to video def change_to_tab_two(): return gr.Tabs(selected=2) #This will navigate def change_to_skyreels_tab(): return gr.Tabs(selected=3) #SKYREELS TAB!!! 
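# --- Illustrative sketch (not wired into the UI) ---
# The .then() lambdas above turn the metadata dict from the Video Info tab into an
# ordered list of values for the WanX controls, falling back to per-field defaults,
# or to gr.update() no-ops when no metadata is available. Below is a minimal
# standalone version of that mapping pattern, assuming the same field names; the
# helper name and the defaults dict are illustrative only.
def _map_metadata_to_outputs_sketch(params: Optional[dict], defaults: Dict[str, object]) -> list:
    """Return one value per output component, in the insertion order of `defaults`."""
    if not params:
        # One no-op update per output keeps Gradio's expected output count satisfied.
        return [gr.update()] * len(defaults)
    return [params.get(key, fallback) for key, fallback in defaults.items()]

# Example (same ordering idea as the WanX-i2v handler above: prompt, width, height, ...):
# _map_metadata_to_outputs_sketch(
#     {"prompt": "a cat on a beach", "width": 832},
#     {"prompt": "", "width": 832, "height": 480},
# ) -> ["a cat on a beach", 832, 480]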
# Add state management for dimensions def sync_skyreels_dimensions(width, height): return gr.update(value=width), gr.update(value=height) # Add this function to update the LoRA dropdowns in the SKYREELS tab def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: new_choices = get_lora_options(lora_folder) weights = current_values[:4] multipliers = current_values[4:8] results = [] for i in range(4): weight = weights[i] if i < len(weights) else "None" multiplier = multipliers[i] if i < len(multipliers) else 1.0 if weight not in new_choices: weight = "None" results.extend([ gr.update(choices=new_choices, value=weight), gr.update(value=multiplier) ]) return results # Add this function to update the models dropdown in the SKYREELS tab def update_skyreels_model_dropdown(dit_folder: str) -> Dict: models = get_dit_models(dit_folder) return gr.update(choices=models, value=models[0] if models else None) # Add event handler for model dropdown refresh skyreels_dit_folder.change( fn=update_skyreels_model_dropdown, inputs=[skyreels_dit_folder], outputs=[skyreels_model] ) # Add handlers for the refresh button skyreels_refresh_btn.click( fn=update_skyreels_lora_dropdowns, inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] ) # Skyreels dimension handling def calculate_skyreels_width(height, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_width = math.floor((height * aspect_ratio) / 16) * 16 return gr.update(value=new_width) def calculate_skyreels_height(width, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_height = math.floor((width / aspect_ratio) / 16) * 16 return gr.update(value=new_height) def update_skyreels_from_scale(scale, original_dims): if not original_dims: return gr.update(), gr.update() orig_w, orig_h = map(int, original_dims.split('x')) new_w = math.floor((orig_w * scale / 100) / 16) * 16 new_h = math.floor((orig_h * scale / 100) / 16) * 16 return gr.update(value=new_w), gr.update(value=new_h) def update_skyreels_dimensions(image): if image is None: return "", gr.update(value=544), gr.update(value=544) img = Image.open(image) w, h = img.size w = (w // 16) * 16 h = (h // 16) * 16 return f"{w}x{h}", w, h def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: return evt.index def send_skyreels_to_v2v( gallery: list, prompt: str, selected_index: int, width: int, height: int, video_length: int, fps: int, infer_steps: int, seed: int, flow_shift: float, cfg_scale: float, lora1: str, lora2: str, lora3: str, lora4: str, lora1_multiplier: float, lora2_multiplier: float, lora3_multiplier: float, lora4_multiplier: float, negative_prompt: str = "" # Add this parameter ) -> Tuple: if not gallery or selected_index is None or selected_index >= len(gallery): return (None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, negative_prompt) # Add negative_prompt to return selected_item = gallery[selected_index] if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = 
selected_item if isinstance(video_path, tuple): video_path = video_path[0] return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, negative_prompt) # Add negative_prompt to return # Add event handlers for the SKYREELS tab skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) # Image input handling skyreels_input.change( fn=update_skyreels_dimensions, inputs=[skyreels_input], outputs=[skyreels_original_dims, skyreels_width, skyreels_height] ) skyreels_scale_slider.change( fn=update_skyreels_from_scale, inputs=[skyreels_scale_slider, skyreels_original_dims], outputs=[skyreels_width, skyreels_height] ) skyreels_calc_width_btn.click( fn=calculate_skyreels_width, inputs=[skyreels_height, skyreels_original_dims], outputs=[skyreels_width] ) skyreels_calc_height_btn.click( fn=calculate_skyreels_height, inputs=[skyreels_width, skyreels_original_dims], outputs=[skyreels_height] ) # Handle checkbox visibility toggling skyreels_use_random_folder.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), inputs=[skyreels_use_random_folder], outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] ) # Validate folder button click handler skyreels_validate_folder_btn.click( fn=lambda folder: get_random_image_from_folder(folder)[1], inputs=[skyreels_input_folder], outputs=[skyreels_folder_status] ) skyreels_use_random_folder.change( fn=lambda x: gr.update(visible=x), inputs=[skyreels_use_random_folder], outputs=[skyreels_validate_folder_btn] ) # Modify the skyreels_generate_btn.click event handler to use process_random_image_batch when folder mode is on skyreels_generate_btn.click( fn=batch_handler, inputs=[ skyreels_use_random_folder, # Rest of the arguments skyreels_prompt, skyreels_negative_prompt, skyreels_width, skyreels_height, skyreels_video_length, skyreels_fps, skyreels_infer_steps, skyreels_seed, skyreels_flow_shift, skyreels_guidance_scale, skyreels_embedded_cfg_scale, skyreels_batch_size, skyreels_input_folder, skyreels_dit_folder, skyreels_model, skyreels_vae, skyreels_te1, skyreels_te2, skyreels_save_path, skyreels_output_type, skyreels_attn_mode, skyreels_block_swap, skyreels_exclude_single_blocks, skyreels_use_split_attn, skyreels_use_fp8, skyreels_split_uncond, skyreels_lora_folder, *skyreels_lora_weights, *skyreels_lora_multipliers, skyreels_input # Add the input image path ], outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[skyreels_batch_size], outputs=skyreels_selected_index ) # Gallery selection handling skyreels_output.select( fn=handle_skyreels_gallery_select, outputs=skyreels_selected_index ) # Send to Video2Video handler skyreels_send_to_v2v_btn.click( fn=send_skyreels_to_v2v, inputs=[ skyreels_output, skyreels_prompt, skyreels_selected_index, skyreels_width, skyreels_height, skyreels_video_length, skyreels_fps, skyreels_infer_steps, skyreels_seed, skyreels_flow_shift, skyreels_guidance_scale ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component outputs=[ v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, 
v2v_flow_shift, v2v_cfg_scale ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] ).then( fn=change_to_tab_two, inputs=None, outputs=[tabs] ) # Refresh button handler skyreels_refresh_outputs = [skyreels_model] for i in range(4): skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) skyreels_refresh_btn.click( fn=update_dit_and_lora_dropdowns, inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, outputs=skyreels_refresh_outputs ) def calculate_v2v_width(height, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 return gr.update(value=new_width) def calculate_v2v_height(width, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 return gr.update(value=new_height) def update_v2v_from_scale(scale, original_dims): if not original_dims: return gr.update(), gr.update() orig_w, orig_h = map(int, original_dims.split('x')) new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 return gr.update(value=new_w), gr.update(value=new_h) def update_v2v_dimensions(video): if video is None: return "", gr.update(value=544), gr.update(value=544) cap = cv2.VideoCapture(video) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() # Make dimensions divisible by 16 w = (w // 16) * 16 h = (h // 16) * 16 return f"{w}x{h}", w, h # Event Handlers for Video to Video Tab v2v_input.change( fn=update_v2v_dimensions, inputs=[v2v_input], outputs=[v2v_original_dims, v2v_width, v2v_height] ) v2v_scale_slider.change( fn=update_v2v_from_scale, inputs=[v2v_scale_slider, v2v_original_dims], outputs=[v2v_width, v2v_height] ) v2v_calc_width_btn.click( fn=calculate_v2v_width, inputs=[v2v_height, v2v_original_dims], outputs=[v2v_width] ) v2v_calc_height_btn.click( fn=calculate_v2v_height, inputs=[v2v_width, v2v_original_dims], outputs=[v2v_height] ) ##Image 2 video dimension logic def calculate_width(height, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 return gr.update(value=new_width) def calculate_height(width, original_dims): if not original_dims: return gr.update() orig_w, orig_h = map(int, original_dims.split('x')) aspect_ratio = orig_w / orig_h new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 return gr.update(value=new_height) def update_from_scale(scale, original_dims): if not original_dims: return gr.update(), gr.update() orig_w, orig_h = map(int, original_dims.split('x')) new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 return gr.update(value=new_w), gr.update(value=new_h) def update_dimensions(image): if image is None: return "", gr.update(value=544), gr.update(value=544) img = Image.open(image) w, h = img.size # Make dimensions divisible by 16 w = (w // 16) * 16 # Changed from 8 to 16 h = (h // 16) * 16 # Changed from 8 
to 16 return f"{w}x{h}", w, h i2v_input.change( fn=update_dimensions, inputs=[i2v_input], outputs=[original_dims, width, height] ) scale_slider.change( fn=update_from_scale, inputs=[scale_slider, original_dims], outputs=[width, height] ) calc_width_btn.click( fn=calculate_width, inputs=[height, original_dims], outputs=[width] ) calc_height_btn.click( fn=calculate_height, inputs=[width, original_dims], outputs=[height] ) # Function to get available DiT models def get_dit_models(dit_folder: str) -> List[str]: if not os.path.exists(dit_folder): return ["mp_rank_00_model_states.pt"] models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] models.sort(key=str.lower) return models if models else ["mp_rank_00_model_states.pt"] # Function to perform model merging def merge_models( dit_folder: str, dit_model: str, output_model: str, exclude_single_blocks: bool, merge_lora_folder: str, *lora_params # Will contain both weights and multipliers ) -> str: try: # Separate weights and multipliers num_loras = len(lora_params) // 2 weights = list(lora_params[:num_loras]) multipliers = list(lora_params[num_loras:]) # Filter out "None" selections valid_loras = [] for weight, mult in zip(weights, multipliers): if weight and weight != "None": valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) if not valid_loras: return "No LoRA models selected for merging" # Create output path in the dit folder os.makedirs(dit_folder, exist_ok=True) output_path = os.path.join(dit_folder, output_model) # Prepare command cmd = [ sys.executable, "merge_lora.py", "--dit", os.path.join(dit_folder, dit_model), "--save_merged_model", output_path ] # Add LoRA weights and multipliers weights = [weight for weight, _ in valid_loras] multipliers = [str(mult) for _, mult in valid_loras] cmd.extend(["--lora_weight"] + weights) cmd.extend(["--lora_multiplier"] + multipliers) if exclude_single_blocks: cmd.append("--exclude_single_blocks") # Execute merge operation result = subprocess.run( cmd, capture_output=True, text=True, check=True ) if os.path.exists(output_path): return f"Successfully merged model and saved to {output_path}" else: return "Error: Output file not created" except subprocess.CalledProcessError as e: return f"Error during merging: {e.stderr}" except Exception as e: return f"Error: {str(e)}" # Update DiT model dropdown def update_dit_dropdown(dit_folder: str) -> Dict: models = get_dit_models(dit_folder) return gr.update(choices=models, value=models[0] if models else None) # Connect events merge_btn.click( fn=merge_models, inputs=[ dit_folder, dit_model, output_model, exclude_single_blocks, merge_lora_folder, *merge_lora_weights, *merge_lora_multipliers ], outputs=merge_status ) # Refresh buttons for both DiT and LoRA dropdowns merge_refresh_btn.click( fn=lambda f: update_dit_dropdown(f), inputs=[dit_folder], outputs=[dit_model] ) # LoRA refresh handling merge_refresh_outputs = [] for i in range(4): merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) merge_refresh_btn.click( fn=update_lora_dropdowns, inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, outputs=merge_refresh_outputs ) # Event handlers prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) stop_btn.click(fn=lambda: stop_event.set(), queue=False) v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) #Image_to_Video def image_to_video(image_path, 
output_path, width, height, frames=240): # Add width, height parameters img = Image.open(image_path) # Resize to the specified dimensions img_resized = img.resize((width, height), Image.LANCZOS) temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") img_resized.save(temp_image_path) # Rest of function remains the same frame_rate = 24 duration = frames / frame_rate command = [ "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", "-t", str(duration), "-pix_fmt", "yuv420p", "-vf", f"fps={frame_rate}", output_path ] try: subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) print(f"Video saved to {output_path}") return True except subprocess.CalledProcessError as e: print(f"An error occurred while creating the video: {e}") return False finally: # Clean up the temporary image file if os.path.exists(temp_image_path): os.remove(temp_image_path) img.close() # Make sure to close the image file explicitly def generate_from_image( image_path, prompt, width, height, video_length, fps, infer_steps, seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder, strength, batch_size, *lora_params ): """Generate video from input image with progressive updates""" global stop_event stop_event.clear() # Create temporary video path temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") try: # Convert image to video if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): yield [], "Failed to create temporary video", "Error in video creation" return # Ensure video is fully written before proceeding time.sleep(1) if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: yield [], "Failed to create temporary video", "Temporary video file is empty or missing" return # Get video dimensions try: probe = ffmpeg.probe(temp_video_path) video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) if video_stream is None: raise ValueError("No video stream found") width = int(video_stream['width']) height = int(video_stream['height']) except Exception as e: yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" return # Generate the video using the temporary file try: generator = process_single_video( prompt, width, height, batch_size, video_length, fps, infer_steps, seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder, *lora_params, video_path=temp_video_path, strength=strength ) # Forward all generator updates for videos, batch_text, progress_text in generator: yield videos, batch_text, progress_text except Exception as e: yield [], f"Error in video generation: {str(e)}", "Generation error" return except Exception as e: yield [], f"Unexpected error: {str(e)}", "Error occurred" return finally: # Clean up temporary file try: if os.path.exists(temp_video_path): os.remove(temp_video_path) except Exception: pass # Ignore cleanup errors # Add event handlers i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) def handle_i2v_gallery_select(evt: gr.SelectData) -> int: """Track selected index when I2V gallery item is clicked""" return evt.index def send_i2v_to_v2v( gallery: list, prompt: str, selected_index: int, width: int, 
height: int, video_length: int, fps: int, infer_steps: int, seed: int, flow_shift: float, cfg_scale: float, lora1: str, lora2: str, lora3: str, lora4: str, lora1_multiplier: float, lora2_multiplier: float, lora3_multiplier: float, lora4_multiplier: float ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: """Send the selected video and parameters from Image2Video tab to Video2Video tab""" if not gallery or selected_index is None or selected_index >= len(gallery): return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier selected_item = gallery[selected_index] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] # Use the original width and height without doubling return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) # Generate button handler i2v_generate_btn.click( fn=process_hunyuani2v_batch, inputs=[ i2v_prompt, width, height, i2v_batch_size, i2v_video_length, i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder, *i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_negative_prompt, i2v_cfg_scale, i2v_split_uncond, i2v_use_fp8, i2v_clip_vision_path, i2v_stability, i2v_fp8_fast, i2v_compile, i2v_compile_backend, i2v_compile_mode, i2v_compile_dynamic, i2v_compile_fullgraph ], outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[i2v_batch_size], outputs=i2v_selected_index ) # Send to Video2Video i2v_output.select( fn=handle_i2v_gallery_select, outputs=i2v_selected_index ) i2v_send_to_v2v_btn.click( fn=send_i2v_to_v2v, inputs=[ i2v_output, i2v_prompt, i2v_selected_index, width, height, i2v_video_length, i2v_fps, i2v_infer_steps, i2v_seed, i2v_flow_shift, i2v_cfg_scale ] + i2v_lora_weights + i2v_lora_multipliers, outputs=[ v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale ] + v2v_lora_weights + v2v_lora_multipliers ).then( fn=change_to_tab_two, inputs=None, outputs=[tabs] ) #Video Info def clean_video_path(video_path) -> str: """Extract clean video path from Gradio's various return formats""" print(f"Input video_path: {video_path}, type: {type(video_path)}") if isinstance(video_path, dict): path = video_path.get("name", "") elif isinstance(video_path, (tuple, list)): path = video_path[0] elif isinstance(video_path, str): path = video_path else: path = "" print(f"Cleaned path: {path}") return path def handle_video_upload(video_path: str) -> Dict: """Handle video upload and metadata extraction""" if not video_path: return {}, "No video uploaded" metadata = extract_video_metadata(video_path) if not metadata: return {}, "No metadata found in video" return metadata, 
"Metadata extracted successfully" def get_video_info(video_path: str) -> dict: try: probe = ffmpeg.probe(video_path) video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') width = int(video_info['width']) height = int(video_info['height']) fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 # Calculate total frames duration = float(probe['format']['duration']) total_frames = int(duration * fps) # Ensure video length does not exceed 201 frames if total_frames > 201: total_frames = 201 duration = total_frames / fps # Adjust duration accordingly return { 'width': width, 'height': height, 'fps': fps, 'total_frames': total_frames, 'duration': duration # Might be useful in some contexts } except Exception as e: print(f"Error extracting video info: {e}") return {} def extract_video_details(video_path: str) -> Tuple[dict, str]: metadata = extract_video_metadata(video_path) video_details = get_video_info(video_path) # Combine metadata with video details for key, value in video_details.items(): if key not in metadata: metadata[key] = value # Ensure video length does not exceed 201 frames if 'video_length' in metadata: metadata['video_length'] = min(metadata['video_length'], 201) else: metadata['video_length'] = min(video_details.get('total_frames', 0), 201) # Return both the updated metadata and a status message return metadata, "Video details extracted successfully" def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: """Create parameter mapping for target tab""" if not metadata: return "No parameters to send", {} tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" try: mapping = create_parameter_transfer_map(metadata, target_tab) return f"Parameters ready for {tab_name}", mapping except Exception as e: return f"Error: {str(e)}", {} video_input.upload( fn=extract_video_details, inputs=video_input, outputs=[metadata_output, status] ) send_to_t2v_btn.click( fn=lambda m: send_parameters_to_tab(m, "t2v"), inputs=metadata_output, outputs=[status, params_state] ).then( fn=change_to_tab_one, inputs=None, outputs=[tabs] ).then( lambda params: [ params.get("prompt", ""), params.get("width", 544), params.get("height", 544), params.get("batch_size", 1), params.get("video_length", 25), params.get("fps", 24), params.get("infer_steps", 30), params.get("seed", -1), params.get("model", "hunyuan/mp_rank_00_model_states.pt"), params.get("vae", "hunyuan/pytorch_model.pt"), params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), params.get("te2", "hunyuan/clip_l.safetensors"), params.get("save_path", "outputs"), params.get("flow_shift", 11.0), params.get("cfg_scale", 7.0), params.get("output_type", "video"), params.get("attn_mode", "sdpa"), params.get("block_swap", "0"), *[params.get(f"lora{i+1}", "") for i in range(4)], *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] ] if params else [gr.update()]*26, inputs=params_state, outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap] + lora_weights + lora_multipliers ) # Text to Video generation generate_btn.click( fn=process_batch, inputs=[ prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), 
gr.Number(visible=False), use_fp8 ], outputs=[video_output, batch_progress, progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[batch_size], outputs=selected_index ) # Update gallery selection handling def handle_gallery_select(evt: gr.SelectData) -> int: return evt.index # Track selected index when gallery item is clicked video_output.select( fn=handle_gallery_select, outputs=selected_index ) # Track selected index when Video2Video gallery item is clicked def handle_v2v_gallery_select(evt: gr.SelectData) -> int: """Handle gallery selection without automatically updating the input""" return evt.index # Update the gallery selection event v2v_output.select( fn=handle_v2v_gallery_select, outputs=v2v_selected_index ) # Send button handler with gallery selection def handle_send_button( gallery: list, prompt: str, idx: int, width: int, height: int, batch_size: int, video_length: int, fps: int, infer_steps: int, seed: int, flow_shift: float, cfg_scale: float, lora1: str, lora2: str, lora3: str, lora4: str, lora1_multiplier: float, lora2_multiplier: float, lora3_multiplier: float, lora4_multiplier: float ) -> tuple: if not gallery or idx is None or idx >= len(gallery): return (None, "", width, height, batch_size, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, "") # Add empty string for negative_prompt in the return values # Auto-select first item if only one exists and no selection made if idx is None and len(gallery) == 1: idx = 0 selected_item = gallery[idx] # Handle different gallery item formats if isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, (tuple, list)): video_path = selected_item[0] else: video_path = selected_item # Final cleanup for Gradio Video component if isinstance(video_path, tuple): video_path = video_path[0] return ( str(video_path), prompt, width, height, batch_size, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, "" # Add empty string for negative_prompt ) send_t2v_to_v2v_btn.click( fn=handle_send_button, inputs=[ video_output, prompt, selected_index, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed, flow_shift, cfg_scale ] + lora_weights + lora_multipliers, # Remove the string here outputs=[ v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] ).then( fn=change_to_tab_two, inputs=None, outputs=[tabs] ) def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: """Handle both parameters and video transfer""" status_msg, params = send_parameters_to_tab(metadata, "v2v") return status_msg, params, video_path def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: """Handle both parameters and video transfer from Video Info to V2V tab""" if not video_path: return "No video selected", {}, None status_msg, params = send_parameters_to_tab(metadata, "v2v") # Just return the path directly return status_msg, params, video_path # Send button click handler send_to_v2v_btn.click( fn=handle_info_to_v2v, inputs=[metadata_output, video_input], outputs=[status, params_state, v2v_input] ).then( lambda params: [ 
params.get("v2v_prompt", ""), params.get("v2v_width", 544), params.get("v2v_height", 544), params.get("v2v_batch_size", 1), params.get("v2v_video_length", 25), params.get("v2v_fps", 24), params.get("v2v_infer_steps", 30), params.get("v2v_seed", -1), params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), params.get("v2v_vae", "hunyuan/pytorch_model.pt"), params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), params.get("v2v_te2", "hunyuan/clip_l.safetensors"), params.get("v2v_save_path", "outputs"), params.get("v2v_flow_shift", 11.0), params.get("v2v_cfg_scale", 7.0), params.get("v2v_output_type", "video"), params.get("v2v_attn_mode", "sdpa"), params.get("v2v_block_swap", "0"), *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] ] if params else [gr.update()] * 26, inputs=params_state, outputs=[ v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, v2v_block_swap ] + v2v_lora_weights + v2v_lora_multipliers ).then( lambda: print(f"Tabs object: {tabs}"), # Debug print outputs=None ).then( fn=change_to_tab_two, inputs=None, outputs=[tabs] ) # Handler for sending selected video from Video2Video gallery to input def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: """Send the currently selected video in V2V gallery to V2V input""" if not gallery or idx is None or idx >= len(gallery): return None, "" selected_item = gallery[idx] video_path = None # Handle different gallery item formats if isinstance(selected_item, tuple): video_path = selected_item[0] # Gallery returns (path, caption) elif isinstance(selected_item, dict): video_path = selected_item.get("name", selected_item.get("data", None)) elif isinstance(selected_item, str): video_path = selected_item if not video_path: return None, "" # Check if the file exists and is accessible if not os.path.exists(video_path): print(f"Warning: Video file not found at {video_path}") return None, "" return video_path, prompt v2v_send_to_input_btn.click( fn=handle_v2v_send_button, inputs=[v2v_output, v2v_prompt, v2v_selected_index], outputs=[v2v_input, v2v_prompt] ).then( lambda: gr.update(visible=True), # Ensure the video input is visible outputs=v2v_input ) # Video to Video generation v2v_generate_btn.click( fn=process_batch, inputs=[ v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 ], outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[v2v_batch_size], outputs=v2v_selected_index ) refresh_outputs = [model] # Add model dropdown to outputs for i in range(4): refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) refresh_btn.click( fn=update_dit_and_lora_dropdowns, inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, outputs=refresh_outputs ) # Image2Video refresh i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs for i in range(4): 
i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) i2v_refresh_btn.click( fn=update_dit_and_lora_dropdowns, inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, outputs=i2v_refresh_outputs ) # Video2Video refresh v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs for i in range(4): v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) v2v_refresh_btn.click( fn=update_dit_and_lora_dropdowns, inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, outputs=v2v_refresh_outputs ) # WanX-i2v tab connections wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) # Image input handling for WanX-i2v wanx_input.change( fn=update_wanx_image_dimensions, inputs=[wanx_input], outputs=[wanx_original_dims, wanx_width, wanx_height] ) # Scale slider handling for WanX-i2v wanx_scale_slider.change( fn=update_wanx_from_scale, inputs=[wanx_scale_slider, wanx_original_dims], outputs=[wanx_width, wanx_height] ) # Width/height calculation buttons for WanX-i2v wanx_calc_width_btn.click( fn=calculate_wanx_width, inputs=[wanx_height, wanx_original_dims], outputs=[wanx_width] ) wanx_calc_height_btn.click( fn=calculate_wanx_height, inputs=[wanx_width, wanx_original_dims], outputs=[wanx_height] ) # Add visibility toggle for the folder input components wanx_use_random_folder.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), inputs=[wanx_use_random_folder], outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] ) # Validate folder button handler wanx_validate_folder_btn.click( fn=lambda folder: get_random_image_from_folder(folder)[1], inputs=[wanx_input_folder], outputs=[wanx_folder_status] ) # Flow shift recommendation buttons wanx_recommend_flow_btn.click( fn=recommend_wanx_flow_shift, inputs=[wanx_width, wanx_height], outputs=[wanx_flow_shift] ) wanx_t2v_recommend_flow_btn.click( fn=recommend_wanx_flow_shift, inputs=[wanx_t2v_width, wanx_t2v_height], outputs=[wanx_t2v_flow_shift] ) # Generate button handler wanx_generate_btn.click( fn=wanx_batch_handler, inputs=[ wanx_use_random_folder, wanx_prompt, wanx_negative_prompt, wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps, wanx_flow_shift, wanx_guidance_scale, wanx_seed, wanx_batch_size, wanx_input_folder, wanx_task, wanx_dit_path, wanx_vae_path, wanx_t5_path, wanx_clip_path, wanx_save_path, wanx_output_type, wanx_sample_solver, wanx_exclude_single_blocks, wanx_attn_mode, wanx_block_swap, wanx_fp8, wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, *wanx_lora_multipliers, wanx_input # Include input image path for non-batch mode ], outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[wanx_batch_size], outputs=wanx_i2v_selected_index # Update to use correct state ) # Add refresh button handler for WanX-i2v tab wanx_refresh_outputs = [] for i in range(4): wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) wanx_refresh_btn.click( fn=update_lora_dropdowns, inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, outputs=wanx_refresh_outputs ) # Gallery selection handling wanx_output.select( fn=handle_wanx_gallery_select, inputs=[wanx_output], outputs=[wanx_i2v_selected_index, wanx_base_video] ) # 
Send to Video2Video handler wanx_send_to_v2v_btn.click( fn=send_wanx_to_v2v, inputs=[ wanx_output, # Gallery with videos wanx_prompt, # Prompt text wanx_i2v_selected_index, # Use the correct selected index state wanx_width, wanx_height, wanx_video_length, wanx_fps, wanx_infer_steps, wanx_seed, wanx_flow_shift, wanx_guidance_scale, wanx_negative_prompt ], outputs=[ v2v_input, # Video input in V2V tab v2v_prompt, # Prompt in V2V tab v2v_width, v2v_height, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale, v2v_negative_prompt ] ).then( fn=change_to_tab_two, # Function to switch to Video2Video tab inputs=None, outputs=[tabs] ) # Add state for T2V tab selected index wanx_t2v_selected_index = gr.State(value=None) # Connect prompt token counter wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) # Stop button handler wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) # Flow shift recommendation button wanx_t2v_recommend_flow_btn.click( fn=recommend_wanx_flow_shift, inputs=[wanx_t2v_width, wanx_t2v_height], outputs=[wanx_t2v_flow_shift] ) # Task change handler to update CLIP visibility and path def update_clip_visibility(task): is_i2v = "i2v" in task return gr.update(visible=is_i2v) wanx_t2v_task.change( fn=update_clip_visibility, inputs=[wanx_t2v_task], outputs=[wanx_t2v_clip_path] ) # Generate button handler for T2V wanx_t2v_generate_btn.click( fn=wanx_generate_video_batch, inputs=[ wanx_t2v_prompt, wanx_t2v_negative_prompt, wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps, wanx_t2v_infer_steps, wanx_t2v_flow_shift, wanx_t2v_guidance_scale, wanx_t2v_seed, wanx_t2v_task, wanx_t2v_dit_path, wanx_t2v_vae_path, wanx_t2v_t5_path, wanx_t2v_clip_path, wanx_t2v_save_path, wanx_t2v_output_type, wanx_t2v_sample_solver, wanx_t2v_exclude_single_blocks, wanx_t2v_attn_mode, wanx_t2v_block_swap, wanx_t2v_fp8, wanx_t2v_fp8_t5, wanx_t2v_lora_folder, *wanx_t2v_lora_weights, *wanx_t2v_lora_multipliers, wanx_t2v_batch_size, # input_image is now optional and not included here ], outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], queue=True ).then( fn=lambda batch_size: 0 if batch_size == 1 else None, inputs=[wanx_t2v_batch_size], outputs=wanx_t2v_selected_index ) # Add refresh button handler for WanX-t2v tab wanx_t2v_refresh_outputs = [] for i in range(4): wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) wanx_t2v_refresh_btn.click( fn=update_lora_dropdowns, inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, outputs=wanx_t2v_refresh_outputs ) # Gallery selection handling wanx_t2v_output.select( fn=handle_wanx_t2v_gallery_select, outputs=wanx_t2v_selected_index ) # Send to Video2Video handler wanx_t2v_send_to_v2v_btn.click( fn=send_wanx_t2v_to_v2v, inputs=[ wanx_t2v_output, wanx_t2v_prompt, wanx_t2v_selected_index, wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps, wanx_t2v_infer_steps, wanx_t2v_seed, wanx_t2v_flow_shift, wanx_t2v_guidance_scale, wanx_t2v_negative_prompt ], outputs=[ v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale, v2v_negative_prompt ] ).then( fn=change_to_tab_two, inputs=None, outputs=[tabs] ) demo.queue().launch(server_name="0.0.0.0", share=False)
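# ---------------------------------------------------------------------------
# Reference sketches (illustrative only). The blocks below are hedged examples
# added for documentation; they are not wired into the UI above and use
# hypothetical, underscore-prefixed names so they cannot shadow the real
# helpers defined earlier in this file.
# ---------------------------------------------------------------------------

# The width/height calculators in the Video2Video and Image2Video tabs all follow
# the same rule: keep the source aspect ratio and snap the result down to a
# multiple of 16, the spatial stride the video models expect. A standalone sketch
# of that rule (the names are hypothetical):
def _fit_width_to_height(height: int, orig_w: int, orig_h: int) -> int:
    """Width matching `height` at the original aspect ratio, floored to a multiple of 16."""
    aspect_ratio = orig_w / orig_h
    return math.floor((height * aspect_ratio) / 16) * 16

def _scale_dims(orig_w: int, orig_h: int, scale_percent: float) -> Tuple[int, int]:
    """Scale both dimensions by a percentage, flooring each to a multiple of 16."""
    new_w = math.floor((orig_w * scale_percent / 100) / 16) * 16
    new_h = math.floor((orig_h * scale_percent / 100) / 16) * 16
    return new_w, new_h

# Example: _scale_dims(1920, 1080, 50) -> (960, 528), both divisible by 16.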
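# Several handlers above (send_i2v_to_v2v, handle_send_button, handle_v2v_send_button)
# repeat the same "handle different gallery item formats" branching. The shared
# pattern, factored into one hypothetical helper for reference:
def _gallery_item_to_path(item) -> Optional[str]:
    """Best-effort extraction of a file path from a Gradio gallery item."""
    if isinstance(item, dict):
        item = item.get("name", item.get("data", None))
    elif isinstance(item, (tuple, list)):
        item = item[0]           # gallery items may be (path, caption) pairs
    if isinstance(item, tuple):  # final cleanup, mirroring the handlers above
        item = item[0]
    return str(item) if item else None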
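# change_to_tab_two is invoked by many of the .then() chains above to jump to the
# Video2Video tab after a transfer; its real definition is earlier in this file.
# The sketch below only illustrates the pattern and makes two assumptions: a
# Gradio 4.x-style component update, and that the Video2Video tab is selected by
# id/index 2.
def _example_change_to_tab_two():
    # On Gradio 3.x the equivalent would be gr.Tabs.update(selected=2).
    return gr.Tabs(selected=2)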
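# count_prompt_tokens is wired to every prompt box above and is defined earlier in
# the file. A minimal version using the tiktoken import at the top of this script,
# consistent with how it is called here (one string in, one displayed value out);
# the "cl100k_base" encoding and the string return format are assumptions:
def _example_count_prompt_tokens(prompt: str) -> str:
    enc = tiktoken.get_encoding("cl100k_base")
    return f"{len(enc.encode(prompt or ''))} tokens"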
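# get_video_info() above converts ffmpeg's rational 'r_frame_rate' string (for
# example "30000/1001") to a float with eval(). A safer, equivalent parse is
# sketched below; _parse_frame_rate is a hypothetical name and is not called by
# the code above.
def _parse_frame_rate(r_frame_rate: str) -> float:
    """Convert an ffmpeg frame-rate fraction such as '30000/1001' to a float."""
    from fractions import Fraction
    return float(Fraction(r_frame_rate))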
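# extract_video_metadata is used by the Video Info handlers above and is defined
# earlier in this file. The sketch below assumes the generation parameters are
# stored as a JSON string in the container's 'comment' tag, which may not match
# the real storage scheme; it is shown only to illustrate the expected
# "video path in, parameter dict out" contract.
def _example_extract_video_metadata(video_path: str) -> Dict:
    try:
        probe = ffmpeg.probe(video_path)
        comment = probe.get("format", {}).get("tags", {}).get("comment", "")
        return json.loads(comment) if comment else {}
    except Exception as e:
        logger.warning(f"Could not read metadata from {video_path}: {e}")
        return {}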
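# get_random_image_from_folder(folder)[1] is used above to validate the WanX-i2v
# random-image folder, so the real helper returns at least (image_path,
# status_message). A sketch under that assumption:
def _example_get_random_image_from_folder(folder: str) -> Tuple[Optional[str], str]:
    exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    if not folder or not os.path.isdir(folder):
        return None, f"Folder not found: {folder}"
    images = [p for p in glob.glob(os.path.join(folder, "*")) if p.lower().endswith(exts)]
    if not images:
        return None, "No images found in folder"
    choice = random.choice(images)
    return choice, f"Found {len(images)} image(s); e.g. {os.path.basename(choice)}"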
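# update_lora_dropdowns is called above with (lora_folder, weight1..4,
# multiplier1..4), and its outputs are interleaved as [weights[i], multipliers[i]]
# for each slot. A sketch consistent with that call site (the real helper is
# defined earlier in this file; listing only .safetensors files is an assumption):
def _example_update_lora_dropdowns(lora_folder: str, *current):
    weights, multipliers = current[:4], current[4:]
    files = ["None"]
    if os.path.isdir(lora_folder):
        files += sorted(
            (f for f in os.listdir(lora_folder) if f.endswith(".safetensors")),
            key=str.lower,
        )
    updates = []
    for w, m in zip(weights, multipliers):
        updates.append(gr.update(choices=files, value=w if w in files else "None"))
        updates.append(gr.update(value=m))  # keep the current multiplier
    return updates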