import argparse from datetime import datetime import gc import random import os import re import time import math from typing import Tuple, Optional, List, Union, Any from pathlib import Path # Added for glob_images in V2V import torch import accelerate from accelerate import Accelerator from safetensors.torch import load_file, save_file from safetensors import safe_open from PIL import Image import cv2 # Added for V2V video loading/resizing import numpy as np # Added for V2V video processing import torchvision.transforms.functional as TF from tqdm import tqdm from networks import lora_wan from utils.safetensors_utils import mem_eff_save_file, load_safetensors from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES import wan from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype from wan.modules.vae import WanVAE from wan.modules.t5 import T5EncoderModel from wan.modules.clip import CLIPModel from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler try: from lycoris.kohya import create_network_from_weights except: pass from utils.model_utils import str_to_dtype from utils.device_utils import clean_memory_on_device # Original load_video/load_images are still needed for Fun-Control / image loading from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device, load_images as hv_load_images, load_video as hv_load_video import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) def parse_args() -> argparse.Namespace: """parse command line arguments""" parser = argparse.ArgumentParser(description="Wan 2.1 inference script") # WAN arguments parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") parser.add_argument( "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." ) parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path") parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path") parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16") parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU") parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path") parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path") # LoRA parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") parser.add_argument( "--save_merged_model", type=str, default=None, help="Save merged model to path. 
If specified, no inference will be performed.", ) # inference parser.add_argument("--prompt", type=str, required=True, help="prompt for generation (describe the continuation for extension)") parser.add_argument( "--negative_prompt", type=str, default=None, help="negative prompt for generation, use default negative prompt if not specified", ) parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") parser.add_argument("--video_length", type=int, default=None, help="Total video length (input+generated) for diffusion processing. Default depends on task/mode.") parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") parser.add_argument( "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." ) parser.add_argument( "--guidance_scale", type=float, default=5.0, help="Guidance scale for classifier free guidance. Default is 5.0.", ) # Modes (mutually exclusive) parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference (standard Wan V2V)") parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") parser.add_argument("--extend_video", type=str, default=None, help="path to video for extending it using initial frames") # Mode specific args parser.add_argument("--strength", type=float, default=0.75, help="Strength for video2video inference (0.0-1.0)") parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video or extension inference") parser.add_argument("--num_input_frames", type=int, default=4, help="Number of frames from start of --extend_video to use as input (min 1)") parser.add_argument("--extend_length", type=int, default=None, help="Number of frames to generate *after* the input frames for --extend_video. Default makes total length match task default (e.g., 81).") # Fun-Control argument (distinct from V2V/I2V/Extend) parser.add_argument( "--control_strength", type=float, default=1.0, help="Strength of control video influence for Fun-Control (1.0 = normal)", ) parser.add_argument( "--control_path", type=str, default=None, help="path to control video for inference with controlnet (Fun-Control model only). video file or directory with images", ) parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") parser.add_argument( "--cfg_skip_mode", type=str, default="none", choices=["early", "late", "middle", "early_late", "alternate", "none"], help="CFG skip mode. each mode skips different parts of the CFG. " " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", ) parser.add_argument( "--cfg_apply_ratio", type=float, default=None, help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", ) parser.add_argument( "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" ) parser.add_argument( "--slg_scale", type=float, default=3.0, help="scale for SLG classifier free guidance. Default is 3.0. 
Ignored if slg_mode is None or uncond", ) parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") parser.add_argument( "--slg_mode", type=str, default=None, choices=["original", "uncond"], help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", ) # Flow Matching parser.add_argument( "--flow_shift", type=float, default=None, help="Shift factor for flow matching schedulers. Default depends on task.", ) parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") parser.add_argument( "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" ) parser.add_argument( "--attn_mode", type=str, default="torch", choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], help="attention mode", ) parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") parser.add_argument( "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" ) parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") parser.add_argument("--compile", action="store_true", help="Enable torch.compile") parser.add_argument( "--compile_args", nargs=4, metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], help="Torch.compile settings", ) args = parser.parse_args() assert (args.latent_path is None or len(args.latent_path) == 0) or ( args.output_type == "images" or args.output_type == "video" ), "latent_path is only supported for images or video output" # --- Mode Exclusivity Checks --- modes = [args.video_path, args.image_path, args.extend_video, args.control_path] num_modes_set = sum(1 for mode in modes if mode is not None) if num_modes_set > 1: active_modes = [] if args.video_path: active_modes.append("--video_path (V2V)") if args.image_path: active_modes.append("--image_path (I2V)") if args.extend_video: active_modes.append("--extend_video (Extend)") if args.control_path: active_modes.append("--control_path (Fun-Control)") # Allow Fun-Control + another mode conceptually, but the script logic needs adjustment if not (num_modes_set == 2 and args.control_path is not None): raise ValueError(f"Only one operation mode can be specified. 
Found: {', '.join(active_modes)}") # Special case: Fun-Control can technically be combined, but let's check task compatibility if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control: raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.") # --- Specific Mode Validations --- if args.extend_video is not None: if args.num_input_frames < 1: raise ValueError("--num_input_frames must be at least 1 for video extension.") if "t2v" in args.task: logger.warning("--extend_video provided, but task is t2v. Using I2V-like conditioning.") # We'll set video_length later based on num_input_frames and extend_length if args.image_path is not None: logger.warning("--image_path is provided. This is standard single-frame I2V.") if "t2v" in args.task: logger.warning("--image_path provided, but task is t2v. Using I2V conditioning.") if args.video_path is not None: logger.info("Running in V2V mode.") # V2V length is determined later if not specified if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control: raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.") return args def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None, is_extend_mode: bool = False) -> Tuple[int, float, int, bool]: """Return default values for each task Args: task: task name (t2v, t2i, i2v etc.) size: size of the video (width, height) is_extend_mode: whether we are in video extension mode Returns: Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip) """ width, height = size if size else (0, 0) # I2V and Extend mode share similar defaults is_i2v_like = "i2v" in task or is_extend_mode if "t2i" in task: return 50, 5.0, 1, False elif is_i2v_like: flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 return 40, flow_shift, 81, True # Default total length 81 else: # t2v or default return 50, 5.0, 81, False # Default total length 81 def setup_args(args: argparse.Namespace) -> argparse.Namespace: """Validate and set default values for optional arguments Args: args: command line arguments Returns: argparse.Namespace: updated arguments """ is_extend_mode = args.extend_video is not None # Get default values for the task default_infer_steps, default_flow_shift, default_video_length, _ = get_task_defaults(args.task, tuple(args.video_size), is_extend_mode) # Apply default values to unset arguments if args.infer_steps is None: args.infer_steps = default_infer_steps if args.flow_shift is None: args.flow_shift = default_flow_shift # --- Video Length Handling --- if is_extend_mode: if args.extend_length is None: # Calculate extend_length to reach the default total length args.extend_length = max(1, default_video_length - args.num_input_frames) logger.info(f"Defaulting --extend_length to {args.extend_length} to reach total length {default_video_length}") # Set the total video_length for processing args.video_length = args.num_input_frames + args.extend_length if args.video_length <= args.num_input_frames: raise ValueError(f"Total video length ({args.video_length}) must be greater than input frames ({args.num_input_frames}). 
Increase --extend_length.") elif args.video_length is None and args.video_path is None: # T2V, I2V (not extend) args.video_length = default_video_length elif args.video_length is None and args.video_path is not None: # V2V auto-detect pass # Delay setting default if V2V and length not specified elif args.video_length is not None: # User specified length pass # Force video_length to 1 for t2i tasks if "t2i" in task: assert args.video_length == 1, f"video_length should be 1 for task {args.task}" # parse slg_layers if args.slg_layers is not None: args.slg_layers = list(map(int, args.slg_layers.split(","))) return args def check_inputs(args: argparse.Namespace) -> Tuple[int, int, Optional[int]]: """Validate video size and potentially length (if not V2V auto-detect) Args: args: command line arguments Returns: Tuple[int, int, Optional[int]]: (height, width, video_length) """ height = args.video_size[0] width = args.video_size[1] size = f"{width}*{height}" is_extend_mode = args.extend_video is not None is_v2v_mode = args.video_path is not None # Check supported sizes unless it's V2V/Extend (input video dictates size) or FunControl if not is_v2v_mode and not is_extend_mode and not WAN_CONFIGS[args.task].is_fun_control: if size not in SUPPORTED_SIZES[args.task]: logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.") video_length = args.video_length # Might be None if V2V auto-detect if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") return height, width, video_length def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]: """calculate dimensions for the generation Args: video_size: video frame size (height, width) video_length: number of frames in the video being processed config: model configuration Returns: Tuple[Tuple[int, int, int, int], int]: ((channels, frames, height, width), seq_len) """ height, width = video_size frames = video_length # calculate latent space dimensions lat_f = (frames - 1) // config.vae_stride[0] + 1 lat_h = height // config.vae_stride[1] lat_w = width // config.vae_stride[2] # calculate sequence length seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f) return ((16, lat_f, lat_h, lat_w), seq_len) # Modified function (replace the original) def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE: """load VAE model with robust path handling Args: args: command line arguments config: model configuration device: device to use dtype: data type for the model Returns: WanVAE: loaded VAE model """ vae_override_path = args.vae vae_filename = config.vae_checkpoint # Get expected filename, e.g., "Wan2.1_VAE.pth" # Assume models are in 'wan' dir relative to script if not otherwise specified vae_base_dir = "wan" final_vae_path = None # 1. Check if args.vae is a valid *existing file path* if vae_override_path and isinstance(vae_override_path, str) and \ (vae_override_path.endswith(".pth") or vae_override_path.endswith(".safetensors")) and \ os.path.isfile(vae_override_path): final_vae_path = vae_override_path logger.info(f"Using VAE override path from --vae: {final_vae_path}") # 2. 
If override is invalid or not provided, construct default path if final_vae_path is None: constructed_path = os.path.join(vae_base_dir, vae_filename) if os.path.isfile(constructed_path): final_vae_path = constructed_path logger.info(f"Constructed default VAE path: {final_vae_path}") if vae_override_path: logger.warning(f"Ignoring potentially invalid --vae argument: {vae_override_path}") else: # 3. Fallback using ckpt_dir if provided and default construction failed if args.ckpt_dir: fallback_path = os.path.join(args.ckpt_dir, vae_filename) if os.path.isfile(fallback_path): final_vae_path = fallback_path logger.info(f"Using VAE path from --ckpt_dir fallback: {final_vae_path}") else: # If all attempts fail, raise error raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}', constructed '{constructed_path}', and fallback '{fallback_path}'") else: raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}' and constructed '{constructed_path}'. No --ckpt_dir provided for fallback.") # At this point, final_vae_path should be valid logger.info(f"Loading VAE model from final path: {final_vae_path}") cache_device = torch.device("cpu") if args.vae_cache_cpu else None vae = WanVAE(vae_path=final_vae_path, device=device, dtype=dtype, cache_device=cache_device) return vae def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel: """load text encoder (T5) model Args: args: command line arguments config: model configuration device: device to use Returns: T5EncoderModel: loaded text encoder model """ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint) tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer) text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, device=device, checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, weight_path=args.t5, fp8=args.fp8_t5, ) return text_encoder def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel: """load CLIP model (for I2V / Extend only) Args: args: command line arguments config: model configuration device: device to use Returns: CLIPModel: loaded CLIP model """ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint) tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer) clip = CLIPModel( dtype=config.clip_dtype, device=device, checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, weight_path=args.clip, ) return clip def load_dit_model( args: argparse.Namespace, config, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype] = None, is_i2v_like: bool = False, # Combined flag for I2V and Extend modes ) -> WanModel: """load DiT model Args: args: command line arguments config: model configuration device: device to use dit_dtype: data type for the model dit_weight_dtype: data type for the model weights. 
None for as-is is_i2v_like: I2V or Extend mode (might affect some model config details) Returns: WanModel: loaded DiT model """ loading_device = "cpu" if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled: loading_device = device loading_weight_dtype = dit_weight_dtype if args.fp8_scaled or args.lora_weight is not None: loading_weight_dtype = dit_dtype # load as-is # do not fp8 optimize because we will merge LoRA weights # Pass the is_i2v_like flag if the underlying loading function uses it model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, is_i2v_like) return model def merge_lora_weights(model: WanModel, args: argparse.Namespace, device: torch.device) -> None: """merge LoRA weights to the model Args: model: DiT model args: command line arguments device: device to use """ if args.lora_weight is None or len(args.lora_weight) == 0: return for i, lora_weight in enumerate(args.lora_weight): if args.lora_multiplier is not None and len(args.lora_multiplier) > i: lora_multiplier = args.lora_multiplier[i] else: lora_multiplier = 1.0 logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") weights_sd = load_file(lora_weight) # apply include/exclude patterns original_key_count = len(weights_sd.keys()) if args.include_patterns is not None and len(args.include_patterns) > i: include_pattern = args.include_patterns[i] regex_include = re.compile(include_pattern) weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") if args.exclude_patterns is not None and len(args.exclude_patterns) > i: original_key_count_ex = len(weights_sd.keys()) exclude_pattern = args.exclude_patterns[i] regex_exclude = re.compile(exclude_pattern) weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} logger.info( f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" ) if len(weights_sd) != original_key_count: remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) remaining_keys.sort() logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") if len(weights_sd) == 0: logger.warning(f"No keys left after filtering.") if args.lycoris: lycoris_net, _ = create_network_from_weights( multiplier=lora_multiplier, file=None, weights_sd=weights_sd, unet=model, text_encoder=None, vae=None, for_inference=True, ) lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) else: network = lora_wan.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) network.merge_to(None, model, weights_sd, device=device, non_blocking=True) synchronize_device(device) logger.info("LoRA weights loaded") # save model here before casting to dit_weight_dtype if args.save_merged_model: logger.info(f"Saving merged model to {args.save_merged_model}") mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory logger.info("Merged model saved") def optimize_model( model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype ) -> None: """optimize the model (FP8 conversion, device move etc.) 
Args: model: dit model args: command line arguments device: device to use dit_dtype: dtype for the model dit_weight_dtype: dtype for the model weights """ if args.fp8_scaled: # load state dict as-is and optimize to fp8 state_dict = model.state_dict() # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) info = model.load_state_dict(state_dict, strict=True, assign=True) logger.info(f"Loaded FP8 optimized weights: {info}") if args.blocks_to_swap == 0: model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.) else: # simple cast to dit_dtype target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) target_device = None if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled logger.info(f"Convert model to {dit_weight_dtype}") target_dtype = dit_weight_dtype if args.blocks_to_swap == 0: logger.info(f"Move model to device: {device}") target_device = device model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations if args.compile: compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args logger.info( f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" ) torch._dynamo.config.cache_size_limit = 32 for i in range(len(model.blocks)): model.blocks[i] = torch.compile( model.blocks[i], backend=compile_backend, mode=compile_mode, dynamic=compile_dynamic.lower() in "true", fullgraph=compile_fullgraph.lower() in "true", ) if args.blocks_to_swap > 0: logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) model.move_to_device_except_swap_blocks(device) model.prepare_block_swap_before_forward() else: # make sure the model is on the right device model.to(device) model.eval().requires_grad_(False) clean_memory_on_device(device) def prepare_t2v_inputs( args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: Optional[WanVAE] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: """Prepare inputs for T2V (including Fun-Control variation) Args: args: command line arguments config: model configuration accelerator: Accelerator instance device: device to use vae: VAE model, required only for Fun-Control Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: (noise, context, context_null, (arg_c, arg_null)) """ # Prepare inputs for T2V # calculate dimensions and sequence length height, width = args.video_size # T2V/FunControl length should be set by setup_args frames = args.video_length if frames is None: raise ValueError("video_length must be determined before calling prepare_t2v_inputs") (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, frames, config) target_shape = (16, lat_f, lat_h, lat_w) # Latent channel dim is 16 # configure negative prompt n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt # set seed seed = args.seed # Seed should be set in generate() if not args.cpu_noise: seed_g = torch.Generator(device=device) seed_g.manual_seed(seed) else: # ComfyUI compatible noise seed_g = 
torch.manual_seed(seed) # load text encoder text_encoder = load_text_encoder(args, config, device) text_encoder.model.to(device) # encode prompt with torch.no_grad(): if args.fp8_t5: with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) else: context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) # free text encoder and clean memory del text_encoder clean_memory_on_device(device) # Fun-Control: encode control video to latent space y = None if config.is_fun_control and args.control_path: if vae is None: raise ValueError("VAE must be provided for Fun-Control input preparation.") logger.info(f"Encoding control video for Fun-Control") control_video = load_control_video(args.control_path, frames, height, width).to(device) vae.to_device(device) with accelerator.autocast(), torch.no_grad(): y = vae.encode([control_video])[0] # Encode video y = y * args.control_strength # Apply strength vae.to_device("cpu" if args.vae_cache_cpu else "cpu") # Move VAE back clean_memory_on_device(device) logger.info(f"Fun-Control conditioning 'y' shape: {y.shape}") # generate noise noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu") noise = noise.to(device) # prepare model input arguments arg_c = {"context": context, "seq_len": seq_len} arg_null = {"context": context_null, "seq_len": seq_len} if y is not None: # Add 'y' only if Fun-Control generated it arg_c["y"] = [y] arg_null["y"] = [y] return noise, context, context_null, (arg_c, arg_null) def load_video_frames(video_path: str, num_frames: int, target_reso: Tuple[int, int]) -> Tuple[List[np.ndarray], torch.Tensor]: """Load the first N frames from a video, resize, return numpy list and normalized tensor. Args: video_path (str): Path to the video file. num_frames (int): Number of frames to load from the start. target_reso (Tuple[int, int]): Target resolution (height, width). Returns: Tuple[List[np.ndarray], torch.Tensor]: - List of numpy arrays (frames) in HWC, RGB, uint8 format. - Tensor of shape [C, F, H, W], float32, range [0, 1]. 
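
    Example (illustrative only; the file name and sizes below are placeholders, not defaults):
        frames_np, frames_t = load_video_frames("clip.mp4", num_frames=8, target_reso=(480, 832))
        # frames_t.shape == torch.Size([3, 8, 480, 832]), values in [0, 1]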
""" logger.info(f"Loading first {num_frames} frames from {video_path}, target reso {target_reso}") target_h, target_w = target_reso cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f"Failed to open video file: {video_path}") # Get total frame count and check if enough frames exist total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames < num_frames: cap.release() raise ValueError(f"Video has only {total_frames} frames, but {num_frames} were requested for input.") # Read frames frames_np = [] for i in range(num_frames): ret, frame = cap.read() if not ret: logger.warning(f"Could only read {len(frames_np)} frames out of {num_frames} requested from {video_path}.") break # Convert BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Resize current_h, current_w = frame_rgb.shape[:2] interpolation = cv2.INTER_AREA if target_h * target_w < current_h * current_w else cv2.INTER_LANCZOS4 frame_resized = cv2.resize(frame_rgb, (target_w, target_h), interpolation=interpolation) frames_np.append(frame_resized) cap.release() if len(frames_np) != num_frames: raise RuntimeError(f"Failed to load the required {num_frames} frames.") # Convert list of numpy arrays to tensor [F, H, W, C] -> [C, F, H, W], range [0, 1] frames_tensor = torch.from_numpy(np.stack(frames_np, axis=0)).permute(0, 3, 1, 2).float() / 255.0 frames_tensor = frames_tensor.permute(1, 0, 2, 3) # [C, F, H, W] logger.info(f"Loaded {len(frames_np)} input frames. Tensor shape: {frames_tensor.shape}") # Return both the original numpy frames (for saving later) and the normalized tensor return frames_np, frames_tensor # Combined function for I2V and Extend modes def prepare_i2v_or_extend_inputs( args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: WanVAE, input_frames_tensor: Optional[torch.Tensor] = None # Required for Extend mode ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: """Prepare inputs for I2V (single image) or Extend (multiple frames).""" if vae is None: raise ValueError("VAE must be provided for I2V/Extend input preparation.") is_extend_mode = input_frames_tensor is not None is_i2v_mode = args.image_path is not None # --- Get Dimensions and Frame Counts --- height, width = args.video_size frames = args.video_length # Total frames for diffusion process if frames is None: raise ValueError("video_length must be set before calling prepare_i2v_or_extend_inputs") num_input_frames = 0 if is_extend_mode: num_input_frames = args.num_input_frames if num_input_frames >= frames: raise ValueError(f"Number of input frames ({num_input_frames}) must be less than total video length ({frames})") elif is_i2v_mode: num_input_frames = 1 # --- Load Input Image(s) / Frames --- img_tensor_for_clip = None # Representative tensor for CLIP img_tensor_for_vae = None # Tensor containing all input frames/image for VAE if is_extend_mode: # Input frames tensor already provided (normalized [0,1]) img_tensor_for_vae = input_frames_tensor.to(device) # Use first frame for CLIP img_tensor_for_clip = img_tensor_for_vae[:, 0:1, :, :] # [C, 1, H, W] logger.info(f"Preparing inputs for Extend mode with {num_input_frames} input frames.") elif is_i2v_mode: # Load single image img = Image.open(args.image_path).convert("RGB") img_cv2 = np.array(img) interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC img_resized_np = cv2.resize(img_cv2, (width, height), interpolation=interpolation) # Normalized [0,1], shape [C, H, W] 
img_tensor_single = TF.to_tensor(img_resized_np).to(device) # Add frame dimension -> [C, 1, H, W] img_tensor_for_vae = img_tensor_single.unsqueeze(1) img_tensor_for_clip = img_tensor_for_vae logger.info("Preparing inputs for standard I2V mode.") else: raise ValueError("Neither extend_video nor image_path provided for I2V/Extend preparation.") # --- Optional End Frame --- has_end_image = args.end_image_path is not None end_img_tensor_vae = None # Normalized [-1, 1], shape [C, 1, H, W] if has_end_image: end_img = Image.open(args.end_image_path).convert("RGB") end_img_cv2 = np.array(end_img) interpolation_end = cv2.INTER_AREA if height < end_img_cv2.shape[0] else cv2.INTER_CUBIC end_img_resized_np = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation_end) # Normalized [0,1], shape [C, H, W] -> [C, 1, H, W] end_img_tensor_load = TF.to_tensor(end_img_resized_np).unsqueeze(1).to(device) end_img_tensor_vae = (end_img_tensor_load * 2.0 - 1.0) # Scale to [-1, 1] for VAE logger.info(f"Loaded end image: {args.end_image_path}") # --- Calculate Latent Dimensions --- lat_f = (frames - 1) // config.vae_stride[0] + 1 # Total latent frames lat_h = height // config.vae_stride[1] lat_w = width // config.vae_stride[2] # Latent frames corresponding to the input pixel frames lat_input_f = (num_input_frames - 1) // config.vae_stride[0] + 1 max_seq_len = math.ceil((lat_f + (1 if has_end_image else 0)) * lat_h * lat_w / (config.patch_size[1] * config.patch_size[2])) logger.info(f"Target latent shape: ({lat_f}, {lat_h}, {lat_w}), Input latent frames: {lat_input_f}, Seq len: {max_seq_len}") # --- Set Seed --- seed = args.seed seed_g = torch.Generator(device=device) if not args.cpu_noise else torch.manual_seed(seed) if not args.cpu_noise: seed_g.manual_seed(seed) # --- Generate Noise --- # Noise for the *entire* processing duration (including input frame slots) noise = torch.randn( 16, lat_f + (1 if has_end_image else 0), lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu" ).to(device) # --- Text Encoding --- n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt text_encoder = load_text_encoder(args, config, device) text_encoder.model.to(device) with torch.no_grad(): if args.fp8_t5: with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) else: context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) del text_encoder clean_memory_on_device(device) # --- CLIP Encoding --- clip = load_clip_model(args, config, device) clip.model.to(device) with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): # Input needs to be [-1, 1], shape [C, 1, H, W] (or maybe [C, F, H, W] if model supports?) 
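        # Note: img_tensor_for_clip can alias img_tensor_for_vae (it is the same tensor in I2V
        # mode, or a slice view of it in Extend mode), so the rescale below is deliberately
        # non-in-place to avoid double-normalizing the frames that the VAE encodes later.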
        # Assuming visual encoder takes one frame: use the representative clip tensor
        clip_input = img_tensor_for_clip * 2.0 - 1.0  # Scale [0,1] -> [-1,1] without mutating the source tensor
        clip_context = clip.visual([clip_input])  # Pass as list [tensor]

    del clip
    clean_memory_on_device(device)

    # --- VAE Encoding for Conditioning Tensor 'y' ---
    vae.to_device(device)

    y_latent_part = torch.zeros(
        config.latent_channels, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device, dtype=vae.dtype
    )

    with accelerator.autocast(), torch.no_grad():
        # Encode the input frames/image (scale [0,1] -> [-1,1])
        input_frames_vae = (img_tensor_for_vae * 2.0 - 1.0).to(dtype=vae.dtype)  # [-1, 1]
        # Pad with zeros if needed to match VAE chunking? Assume encode handles variable length for now.
        encoded_input_latents = vae.encode([input_frames_vae])[0]  # [C', F_in', H', W']

        actual_encoded_input_f = encoded_input_latents.shape[1]
        if actual_encoded_input_f > lat_input_f:
            logger.warning(f"VAE encoded {actual_encoded_input_f} frames, expected {lat_input_f}. Truncating.")
            encoded_input_latents = encoded_input_latents[:, :lat_input_f, :, :]
        elif actual_encoded_input_f < lat_input_f:
            logger.warning(f"VAE encoded {actual_encoded_input_f} frames, expected {lat_input_f}. Padding needed for mask.")
            # This case shouldn't happen if lat_input_f calculation is correct, but handle defensively

        # Place encoded input latents into the full y tensor
        y_latent_part[:, :actual_encoded_input_f, :, :] = encoded_input_latents

        # Encode end image if present
        if has_end_image and end_img_tensor_vae is not None:
            encoded_end_latent = vae.encode([end_img_tensor_vae.to(dtype=vae.dtype)])[0]  # [C', 1, H', W']
            y_latent_part[:, -1:, :, :] = encoded_end_latent  # Place at the end

    # --- Create Mask ---
    msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device, dtype=vae.dtype)
    msk[:, :lat_input_f, :, :] = 1  # Mask the input frames
    if has_end_image:
        msk[:, -1:, :, :] = 1  # Mask the end frame

    # --- Combine Mask and Latent Part for 'y' ---
    y = torch.cat([msk, y_latent_part], dim=0)  # Shape [4+C', F_total', H', W']
    logger.info(f"Constructed conditioning 'y' tensor shape: {y.shape}")

    # --- Fun-Control Integration (Optional, might need adjustment for Extend mode) ---
    if config.is_fun_control and args.control_path:
        logger.warning("Fun-Control with Extend mode is experimental. Control signal might conflict with input frames.")
        control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
        with accelerator.autocast(), torch.no_grad():
            control_latent = vae.encode([control_video])[0]  # Encode control video
            control_latent = control_latent * args.control_strength  # Apply strength
        # How to combine? Replace y? Add? For now, let's assume control replaces the VAE part of y
        y = torch.cat([msk, control_latent], dim=0)  # Overwrite latent part with control
        logger.info(f"Replaced latent part of 'y' with Fun-Control latent. New 'y' shape: {y.shape}")

    vae.to_device("cpu")  # Move VAE back to CPU (its cache device was fixed at construction)
    clean_memory_on_device(device)

    # --- Prepare Model Input Dictionaries ---
    arg_c = {
        "context": [context[0]],  # Needs list format?
Check model forward "clip_fea": clip_context, "seq_len": max_seq_len, "y": [y], # Pass conditioning tensor y } arg_null = { "context": context_null, "clip_fea": clip_context, "seq_len": max_seq_len, "y": [y], # Pass conditioning tensor y } return noise, context, context_null, y, (arg_c, arg_null) # --- V2V Helper Functions --- def load_video(video_path, start_frame=0, num_frames=None, bucket_reso=(256, 256)): """Load video frames and resize them to the target resolution for V2V. Args: video_path (str): Path to the video file start_frame (int): First frame to load (0-indexed) num_frames (int, optional): Number of frames to load. If None, load all frames from start_frame. bucket_reso (tuple): Target resolution (height, width) Returns: list: List of numpy arrays containing video frames in RGB format, resized. int: Actual number of frames loaded. """ logger.info(f"Loading video for V2V from {video_path}, target reso {bucket_reso}, frames {start_frame}-{start_frame+num_frames if num_frames else 'end'}") cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f"Failed to open video file: {video_path}") # Get total frame count and FPS total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) logger.info(f"Input video has {total_frames} frames, {fps} FPS") # Calculate how many frames to load if num_frames is None: frames_to_load = total_frames - start_frame else: # Make sure we don't try to load more frames than exist frames_to_load = min(num_frames, total_frames - start_frame) if frames_to_load <= 0: cap.release() logger.warning(f"No frames to load (start_frame={start_frame}, num_frames={num_frames}, total_frames={total_frames})") return [], 0 # Skip to start frame if start_frame > 0: cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) # Read frames frames = [] target_h, target_w = bucket_reso for i in range(frames_to_load): ret, frame = cap.read() if not ret: logger.warning(f"Could only read {len(frames)} frames out of {frames_to_load} requested.") break # Convert from BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Resize the frame current_h, current_w = frame_rgb.shape[:2] interpolation = cv2.INTER_AREA if target_h * target_w < current_h * current_w else cv2.INTER_LANCZOS4 frame_resized = cv2.resize(frame_rgb, (target_w, target_h), interpolation=interpolation) frames.append(frame_resized) cap.release() actual_frames_loaded = len(frames) logger.info(f"Successfully loaded and resized {actual_frames_loaded} frames for V2V.") return frames, actual_frames_loaded def encode_video_to_latents(video_tensor: torch.Tensor, vae: WanVAE, device: torch.device, vae_dtype: torch.dtype, args: argparse.Namespace) -> torch.Tensor: """Encode video tensor to latent space using VAE for V2V. Args: video_tensor (torch.Tensor): Video tensor with shape [B, C, F, H, W], values in [-1, 1]. vae (WanVAE): VAE model instance. device (torch.device): Device to perform encoding on. vae_dtype (torch.dtype): Target dtype for the output latents. args (argparse.Namespace): Command line arguments (needed for vae_cache_cpu). Returns: torch.Tensor: Encoded latents with shape [B, C', F', H', W']. 
""" if vae is None: raise ValueError("VAE must be provided for video encoding.") logger.info(f"Encoding video tensor to latents: input shape {video_tensor.shape}") # Ensure VAE is on the correct device vae.to_device(device) # Prepare video tensor: move to device, ensure correct dtype video_tensor = video_tensor.to(device=device, dtype=vae.dtype) # Use VAE's dtype # WanVAE expects input as a list of [C, F, H, W] tensors (no batch dim) latents_list = [] batch_size = video_tensor.shape[0] for i in range(batch_size): video_single = video_tensor[i] # Shape [C, F, H, W] with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae.dtype): encoded_latent = vae.encode([video_single])[0] # Returns tensor [C', F', H', W'] latents_list.append(encoded_latent) # Stack results back into a batch latents = torch.stack(latents_list, dim=0) # Shape [B, C', F', H', W'] # Move VAE back to CPU (or cache device) vae_target_device = torch.device("cpu") if not args.vae_cache_cpu else torch.device("cpu") if args.vae_cache_cpu: logger.info("Moving VAE to CPU for caching.") else: logger.info("Moving VAE to CPU after encoding.") vae.to_device(vae_target_device) clean_memory_on_device(device) # Convert latents to the desired final dtype (e.g., bfloat16 for DiT) latents = latents.to(dtype=vae_dtype) # Use the target vae_dtype passed to function logger.info(f"Encoded video latents shape: {latents.shape}, dtype: {latents.dtype}") return latents def prepare_v2v_inputs(args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, video_latents: torch.Tensor): """Prepare inputs for Video2Video inference based on encoded video latents. Args: args (argparse.Namespace): Command line arguments. config: Model configuration. accelerator: Accelerator instance. device (torch.device): Device to use. video_latents (torch.Tensor): Encoded latent representation of input video [B, C', F', H', W']. Returns: Tuple containing noise, context, context_null, (arg_c, arg_null). """ # Get dimensions directly from the video latents if len(video_latents.shape) != 5: raise ValueError(f"Expected video_latents to have 5 dimensions [B, C, F, H, W], but got shape {video_latents.shape}") batch_size, latent_channels, lat_f, lat_h, lat_w = video_latents.shape if batch_size != 1: logger.warning(f"V2V input preparation currently assumes batch size 1, but got {batch_size}. 
Using first item.") video_latents = video_latents[0:1] # Keep batch dim # Calculate target shape and sequence length based on actual latent dimensions target_shape = video_latents.shape[1:] # [C', F', H', W'] (_, _, _), seq_len = calculate_dimensions((args.video_size[0], args.video_size[1]), args.video_length, config) # Use original args to get seq_len # (_, _, _), seq_len = calculate_dimensions((lat_h * config.vae_stride[1], lat_w * config.vae_stride[2]), (lat_f-1)*config.vae_stride[0]+1, config) # Recalculate seq_len from latent dims logger.info(f"V2V derived latent shape: {target_shape}, seq_len: {seq_len}") # Configure negative prompt n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt # Set seed (already set in generate(), just need generator) seed = args.seed if not args.cpu_noise: seed_g = torch.Generator(device=device) seed_g.manual_seed(seed) else: seed_g = torch.manual_seed(seed) # Load text encoder text_encoder = load_text_encoder(args, config, device) text_encoder.model.to(device) # Encode prompt with torch.no_grad(): if args.fp8_t5: with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) else: context = text_encoder([args.prompt], device) context_null = text_encoder([n_prompt], device) # Free text encoder and clean memory del text_encoder clean_memory_on_device(device) # Generate noise with the same shape as video_latents (including batch dimension) noise = torch.randn( video_latents.shape, # [B, C', F', H', W'] dtype=torch.float32, device=device if not args.cpu_noise else "cpu", generator=seed_g ) noise = noise.to(device) # Ensure noise is on the target device # Prepare model input arguments (context needs to match batch size of latents) arg_c = {"context": context, "seq_len": seq_len} arg_null = {"context": context_null, "seq_len": seq_len} # V2V does not use 'y' or 'clip_fea' in the standard Wan model case return noise, context, context_null, (arg_c, arg_null) # --- End V2V Helper Functions --- def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor: """load control video to pixel space for Fun-Control model Args: control_path: path to control video frames: number of frames in the video height: height of the video width: width of the video Returns: torch.Tensor: control video tensor, CFHW, range [-1, 1] """ logger.info(f"Load control video for Fun-Control from {control_path}") # Use the original helper from hv_generate_video for consistency if os.path.isfile(control_path): # Use hv_load_video which returns list of numpy arrays (HWC, 0-255) # NOTE: hv_load_video takes (W, H) for bucket_reso! video_frames_np = hv_load_video(control_path, 0, frames, bucket_reso=(width, height)) elif os.path.isdir(control_path): # Use hv_load_images which returns list of numpy arrays (HWC, 0-255) # NOTE: hv_load_images takes (W, H) for bucket_reso! video_frames_np = hv_load_images(control_path, frames, bucket_reso=(width, height)) else: raise FileNotFoundError(f"Control path not found: {control_path}") if not video_frames_np: raise ValueError(f"No frames loaded from control path: {control_path}") if len(video_frames_np) < frames: logger.warning(f"Control video has {len(video_frames_np)} frames, less than requested {frames}. 
Using available frames and repeating last.") # Repeat last frame to match length last_frame = video_frames_np[-1] video_frames_np.extend([last_frame] * (frames - len(video_frames_np))) # Stack and convert to tensor: F, H, W, C (0-255) -> F, C, H, W (-1 to 1) video_frames_np = np.stack(video_frames_np, axis=0) video_tensor = torch.from_numpy(video_frames_np).permute(0, 3, 1, 2).float() / 127.5 - 1.0 # Normalize to [-1, 1] # Permute to C, F, H, W video_tensor = video_tensor.permute(1, 0, 2, 3) logger.info(f"Loaded Fun-Control video tensor shape: {video_tensor.shape}") return video_tensor def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: """setup scheduler for sampling Args: args: command line arguments config: model configuration device: device to use Returns: Tuple[Any, torch.Tensor]: (scheduler, timesteps) """ if args.sample_solver == "unipc": scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) timesteps = scheduler.timesteps elif args.sample_solver == "dpm++": scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False ) sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) elif args.sample_solver == "vanilla": scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) scheduler.set_timesteps(args.infer_steps, device=device) timesteps = scheduler.timesteps # FlowMatchDiscreteScheduler does not support generator argument in step method org_step = scheduler.step def step_wrapper( model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, return_dict: bool = True, generator=None, # Add generator argument here ): # Call original step, ignoring generator if it doesn't accept it try: # Try calling with generator if the underlying class was updated return org_step(model_output, timestep, sample, return_dict=return_dict, generator=generator) except TypeError: # Fallback to calling without generator # logger.warning("Scheduler step does not support generator argument, proceeding without it.") # Reduce noise return org_step(model_output, timestep, sample, return_dict=return_dict) scheduler.step = step_wrapper else: raise NotImplementedError(f"Unsupported solver: {args.sample_solver}") logger.info(f"Using scheduler: {args.sample_solver}, timesteps shape: {timesteps.shape}") return scheduler, timesteps def run_sampling( model: WanModel, noise: torch.Tensor, # This might be pure noise (T2V/I2V/Extend) or mixed noise+latent (V2V) scheduler: Any, timesteps: torch.Tensor, # Might be a subset for V2V args: argparse.Namespace, inputs: Tuple[dict, dict], # (arg_c, arg_null) device: torch.device, seed_g: torch.Generator, accelerator: Accelerator, use_cpu_offload: bool = True, ) -> torch.Tensor: """run sampling loop (Denoising) Args: model: dit model noise: initial latent state (pure noise or mixed noise/video latent) scheduler: scheduler for sampling timesteps: time steps for sampling (can be subset for V2V) args: command line arguments inputs: model input dictionaries (arg_c, arg_null) containing context etc. 
device: device to use seed_g: random generator accelerator: Accelerator instance use_cpu_offload: Whether to offload tensors to CPU during processing Returns: torch.Tensor: generated latent """ arg_c, arg_null = inputs # Ensure inputs (context, y, etc.) are correctly formatted (e.g., lists if model expects list input) # Example: ensure context is list [tensor] if model expects list if isinstance(arg_c.get("context"), torch.Tensor): arg_c["context"] = [arg_c["context"]] if isinstance(arg_null.get("context"), torch.Tensor): arg_null["context"] = [arg_null["context"]] # Similar checks/conversions for other keys like 'y' if needed based on WanModel.forward signature latent = noise # Initialize latent state [B, C, F, H, W] latent_storage_device = device if not use_cpu_offload else "cpu" latent = latent.to(latent_storage_device) # Move initial state to storage device # cfg skip logic apply_cfg_array = [] num_timesteps = len(timesteps) if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None: # Calculate thresholds based on cfg_apply_ratio apply_steps = int(num_timesteps * args.cfg_apply_ratio) if args.cfg_skip_mode == "early": start_index = num_timesteps - apply_steps; end_index = num_timesteps elif args.cfg_skip_mode == "late": start_index = 0; end_index = apply_steps elif args.cfg_skip_mode == "early_late": start_index = (num_timesteps - apply_steps) // 2; end_index = start_index + apply_steps elif args.cfg_skip_mode == "middle": skip_steps = num_timesteps - apply_steps middle_start = (num_timesteps - skip_steps) // 2; middle_end = middle_start + skip_steps else: # Includes "alternate" - handled inside loop start_index = 0; end_index = num_timesteps # Default range for alternate w = 0.0 # For alternate mode for step_idx in range(num_timesteps): apply = True # Default if args.cfg_skip_mode == "alternate": w += args.cfg_apply_ratio; apply = w >= 1.0 if apply: w -= 1.0 elif args.cfg_skip_mode == "middle": apply = not (step_idx >= middle_start and step_idx < middle_end) elif args.cfg_skip_mode != "none": # early, late, early_late apply = step_idx >= start_index and step_idx < end_index apply_cfg_array.append(apply) pattern = ["A" if apply else "S" for apply in apply_cfg_array] pattern = "".join(pattern) logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, steps: {num_timesteps}, pattern: {pattern}") else: # Apply CFG on all steps apply_cfg_array = [True] * num_timesteps # SLG (Skip Layer Guidance) setup apply_slg_global = args.slg_layers is not None and args.slg_mode is not None slg_start_step = int(args.slg_start * num_timesteps) slg_end_step = int(args.slg_end * num_timesteps) logger.info(f"Starting sampling loop for {num_timesteps} steps.") for i, t in enumerate(tqdm(timesteps)): # Prepare input for the model (move latent to compute device) # Latent should be [B, C, F, H, W] # Model expects latent input 'x' as list: [tensor] latent_on_device = latent.to(device) latent_model_input_list = [latent_on_device] # Wrap in list timestep = torch.stack([t]).to(device) # Ensure timestep is a tensor on device with accelerator.autocast(), torch.no_grad(): # 1. Predict conditional noise estimate noise_pred_cond = model(x=latent_model_input_list, t=timestep, **arg_c)[0] noise_pred_cond = noise_pred_cond.to(latent_storage_device) # 2. 
Predict unconditional noise estimate (potentially with SLG) apply_cfg = apply_cfg_array[i] if apply_cfg: apply_slg_step = apply_slg_global and (i >= slg_start_step and i < slg_end_step) slg_indices_for_call = args.slg_layers if apply_slg_step else None uncond_input_args = arg_null if apply_slg_step and args.slg_mode == "original": # Standard uncond prediction first noise_pred_uncond = model(x=latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) # SLG prediction (skipping layers in uncond) skip_layer_out = model(x=latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) # Combine: scaled = uncond + scale * (cond - uncond) + slg_scale * (cond - skip) noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out) elif apply_slg_step and args.slg_mode == "uncond": # SLG prediction (skipping layers in uncond) replaces standard uncond noise_pred_uncond = model(x=latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) # Combine: scaled = slg_uncond + scale * (cond - slg_uncond) noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) else: # Regular CFG (no SLG or SLG not active this step) noise_pred_uncond = model(x=latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) # Combine: scaled = uncond + scale * (cond - uncond) noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) else: # CFG is skipped for this step, use conditional prediction directly noise_pred = noise_pred_cond # 3. Compute previous sample state with the scheduler # Scheduler expects noise_pred [B, C, F, H, W] and latent [B, C, F, H, W] scheduler_output = scheduler.step( noise_pred.to(device), # Ensure noise_pred is on compute device t, latent_on_device, # Pass the tensor directly return_dict=False, generator=seed_g # Pass generator ) prev_latent = scheduler_output[0] # Get the new latent state [B, C, F, H, W] # 4. Update latent state (move back to storage device) latent = prev_latent.to(latent_storage_device) # Return the final denoised latent (should be on storage device) logger.info("Sampling loop finished.") return latent def generate(args: argparse.Namespace) -> Tuple[Optional[torch.Tensor], Optional[List[np.ndarray]]]: """main function for generation pipeline (T2V, I2V, V2V, Extend) Args: args: command line arguments Returns: Tuple[Optional[torch.Tensor], Optional[List[np.ndarray]]]: - generated latent tensor [B, C, F, H, W], or None if error/skipped. - list of original input frames (numpy HWC RGB uint8) if in Extend mode, else None. 
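
    Note: when --save_merged_model is set, generation is skipped and (None, None) is returned
    after the merged weights are written.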
""" device = torch.device(args.device) cfg = WAN_CONFIGS[args.task] # --- Determine Mode --- is_extend_mode = args.extend_video is not None is_i2v_mode = args.image_path is not None and not is_extend_mode is_v2v_mode = args.video_path is not None is_fun_control = args.control_path is not None and cfg.is_fun_control # Can overlap is_t2v_mode = not is_extend_mode and not is_i2v_mode and not is_v2v_mode and not is_fun_control mode_str = ("Extend" if is_extend_mode else "I2V" if is_i2v_mode else "V2V" if is_v2v_mode else "T2V" + ("+FunControl" if is_fun_control else "")) if is_fun_control and not is_t2v_mode: # If funcontrol combined with other modes mode_str += "+FunControl" logger.info(f"Running in {mode_str} mode") # --- Data Types --- dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16 if dit_dtype.itemsize == 1: dit_dtype = torch.bfloat16 if args.fp8_scaled: raise ValueError("Cannot use --fp8_scaled with pre-quantized FP8 weights.") dit_weight_dtype = None elif args.fp8_scaled: dit_weight_dtype = None elif args.fp8: dit_weight_dtype = torch.float8_e4m3fn else: dit_weight_dtype = dit_dtype vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else (torch.bfloat16 if dit_dtype == torch.bfloat16 else torch.float16) logger.info( f"Using device: {device}, DiT compute: {dit_dtype}, DiT weight: {dit_weight_dtype or 'Mixed (FP8 Scaled)' if args.fp8_scaled else dit_dtype}, VAE: {vae_dtype}, T5 FP8: {args.fp8_t5}" ) # --- Accelerator --- mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) # --- Seed --- seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) args.seed = seed logger.info(f"Using seed: {seed}") # --- Load VAE (if needed for input processing) --- vae = None needs_vae_early = is_extend_mode or is_i2v_mode or is_v2v_mode or is_fun_control if needs_vae_early: vae = load_vae(args, cfg, device, vae_dtype) # --- Prepare Inputs --- noise = None context = None context_null = None inputs = None video_latents = None # For V2V mixing original_input_frames_np = None # For Extend mode saving if is_extend_mode: # 1. Load initial frames (numpy list and normalized tensor) original_input_frames_np, input_frames_tensor = load_video_frames( args.extend_video, args.num_input_frames, tuple(args.video_size) ) # 2. Prepare inputs using the loaded frames tensor noise, context, context_null, _, inputs = prepare_i2v_or_extend_inputs( args, cfg, accelerator, device, vae, input_frames_tensor=input_frames_tensor ) del input_frames_tensor # Free memory clean_memory_on_device(device) elif is_i2v_mode: # Prepare I2V inputs (single image) noise, context, context_null, _, inputs = prepare_i2v_or_extend_inputs( args, cfg, accelerator, device, vae ) elif is_v2v_mode: # 1. 
    elif is_v2v_mode:
        # 1. Load and prepare the source video
        video_frames_np, actual_frames_loaded = load_video(
            args.video_path, start_frame=0, num_frames=args.video_length, bucket_reso=tuple(args.video_size)
        )
        if actual_frames_loaded == 0:
            raise ValueError(f"Could not load frames from video: {args.video_path}")
        if args.video_length is None or actual_frames_loaded < args.video_length:
            logger.info(f"Updating video_length based on loaded V2V frames: {actual_frames_loaded}")
            args.video_length = actual_frames_loaded
            height, width, video_length = check_inputs(args)  # re-check with the updated length

        # Convert frames: np [F, H, W, C] uint8 -> tensor [1, C, F, H, W] float32 in [-1, 1]
        video_tensor = torch.from_numpy(np.stack(video_frames_np, axis=0))
        video_tensor = video_tensor.permute(0, 3, 1, 2).float()  # F, C, H, W
        video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # 1, C, F, H, W
        video_tensor = video_tensor / 127.5 - 1.0  # normalize to [-1, 1]

        # 2. Encode video to latents (pass vae_dtype for DiT compatibility)
        video_latents = encode_video_to_latents(video_tensor, vae, device, vae_dtype, args)
        del video_tensor, video_frames_np
        clean_memory_on_device(device)

        # 3. Prepare V2V inputs (noise, context, etc.)
        noise, context, context_null, inputs = prepare_v2v_inputs(args, cfg, accelerator, device, video_latents)

    elif is_t2v_mode or is_fun_control:
        # T2V (optionally with Fun-Control); the VAE is passed only when Fun-Control is active
        if args.video_length is None:
            raise ValueError("video_length must be specified for T2V/Fun-Control.")
        noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae if is_fun_control else None)

    # At this point, the VAE should be on CPU/cache unless it is still needed for decoding

    # --- Load DiT Model ---
    is_i2v_like = is_i2v_mode or is_extend_mode
    model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v_like)

    # --- Merge LoRA ---
    if args.lora_weight is not None and len(args.lora_weight) > 0:
        merge_lora_weights(model, args, device)
        if args.save_merged_model:
            logger.info("Merged model saved. Exiting without generation.")
            return None, None

    # --- Optimize Model ---
    optimize_model(model, args, device, dit_dtype, dit_weight_dtype)

    # --- Setup Scheduler & Timesteps ---
    scheduler, timesteps = setup_scheduler(args, cfg, device)

    # --- Prepare for Sampling ---
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)
    latent = noise  # start from noise (already correctly shaped for T2V/I2V/Extend)
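    # Worked example of the strength-to-steps mapping below (assuming the scheduler yields
    # one timestep per inference step): with --infer_steps 50 and --strength 0.75,
    # num_inference_steps = int(50 * 0.75) = 37, so sampling starts at
    # timesteps[50 - 37] = timesteps[13] and only the last 37 timesteps are run.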
    # --- V2V Strength Adjustment ---
    if is_v2v_mode and args.strength < 1.0:
        if video_latents is None:
            raise RuntimeError("video_latents not available for V2V strength.")
        num_inference_steps = max(1, int(args.infer_steps * args.strength))
        logger.info(f"V2V Strength: {args.strength}, adjusting inference steps to {num_inference_steps}")
        t_start_idx = len(timesteps) - num_inference_steps
        if t_start_idx < 0:
            t_start_idx = 0
        t_start = timesteps[t_start_idx]

        # Use scheduler.add_noise for proper mixing based on the start timestep
        video_latents = video_latents.to(device=noise.device, dtype=noise.dtype)
        latent = scheduler.add_noise(video_latents, noise, t_start.unsqueeze(0).expand(noise.shape[0]))
        latent = latent.to(noise.dtype)  # ensure correct dtype after add_noise
        logger.info(f"Mixed noise and video latents using scheduler.add_noise at timestep {t_start.item():.1f}")

        timesteps = timesteps[t_start_idx:]  # use only the remaining timesteps
        logger.info(f"Using last {len(timesteps)} timesteps for V2V sampling.")
    else:
        logger.info(f"Using full {len(timesteps)} timesteps for sampling.")
        # latent remains the initial noise (I2V/Extend conditioning is handled via the 'y' inputs)

    # --- Run Sampling Loop ---
    logger.info("Starting denoising sampling loop...")
    final_latent = run_sampling(
        model,
        latent,
        scheduler,
        timesteps,
        args,
        inputs,
        device,
        seed_g,
        accelerator,
        use_cpu_offload=(args.blocks_to_swap > 0),
    )

    # --- Cleanup ---
    del model, scheduler, context, context_null, inputs
    if video_latents is not None:
        del video_latents
    synchronize_device(device)
    if args.blocks_to_swap > 0:
        logger.info("Waiting 5 seconds for block swap cleanup...")
        time.sleep(5)
    gc.collect()
    clean_memory_on_device(device)

    # Keep the VAE instance around for decoding
    args._vae = vae

    # Return latent [B, C, F, H, W] and the original frames if extending
    if len(final_latent.shape) == 4:
        final_latent = final_latent.unsqueeze(0)
    return final_latent, original_input_frames_np
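
# The latent <-> pixel size relationship used when logging dimensions in main() follows the
# VAE strides. Minimal sketch (not called by the pipeline); assumes cfg.vae_stride is ordered
# (temporal, height, width) as used elsewhere in this script:
def _latent_to_pixel_dims_example(lat_f: int, lat_h: int, lat_w: int, vae_stride: Tuple[int, int, int]) -> Tuple[int, int, int]:
    # Spatial dims scale linearly; the temporal dim packs stride_t pixel frames per latent
    # frame plus one leading frame: frames = (lat_f - 1) * stride_t + 1.
    height = lat_h * vae_stride[1]
    width = lat_w * vae_stride[2]
    frames = (lat_f - 1) * vae_stride[0] + 1
    return height, width, frames
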
def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
    """decode latent tensor to video frames

    Args:
        latent: latent tensor [B, C, F, H, W]
        args: command line arguments (contains the _vae instance)
        cfg: model configuration

    Returns:
        torch.Tensor: decoded video tensor [B, C, F, H, W], range [0, 1], on CPU
    """
    device = torch.device(args.device)

    vae = None
    if hasattr(args, "_vae") and args._vae is not None:
        vae = args._vae
        logger.info("Using VAE instance from generation pipeline for decoding.")
    else:
        logger.info("Loading VAE for decoding...")
        vae_dtype_decode = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16  # default bfloat16
        vae = load_vae(args, cfg, device, vae_dtype_decode)
        args._vae = vae

    vae.to_device(device)
    logger.info(f"Decoding video from latents: shape {latent.shape}, dtype {latent.dtype}")
    latent_decode = latent.to(device=device, dtype=vae.dtype)

    videos = None
    with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
        # Assuming vae.decode handles a batch tensor [B, C, F, H, W] and returns a list of [C, F, H, W]
        decoded_list = vae.decode(latent_decode)
        if decoded_list and len(decoded_list) > 0:
            videos = torch.stack(decoded_list, dim=0)  # stack the list back into a batch: B, C, F, H, W
        else:
            raise RuntimeError("VAE decoding failed or returned an empty list.")

    vae.to_device("cpu")  # move the VAE back to CPU
    clean_memory_on_device(device)

    logger.info(f"Decoded video shape: {videos.shape}")

    # Post-processing: scale [-1, 1] -> [0, 1], clamp, move to CPU float32
    videos = (videos + 1.0) / 2.0
    videos = torch.clamp(videos, 0.0, 1.0)
    video_final = videos.cpu().to(torch.float32)

    # Apply tail-frame trimming *after* decoding
    if args.trim_tail_frames > 0:
        logger.info(f"Trimming last {args.trim_tail_frames} frames from decoded video.")
        video_final = video_final[:, :, : -args.trim_tail_frames, :, :]

    logger.info(f"Decoding complete. Final video tensor shape: {video_final.shape}")
    return video_final
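
# Shape bookkeeping for Extend mode in save_output below: the decoded tensor contains the
# reconstructed input frames followed by the generated continuation, so the saved video is
# the original pixels concatenated with the generated tail along the frame axis.
# Illustrative example with hypothetical sizes (N input frames, M generated frames):
#   original_frames_tensor: [1, C, N, H, W]      (from the source video, range [0, 1])
#   video_tensor:           [1, C, N + M, H, W]  (decoded from latents)
#   video_to_save = torch.cat((original_frames_tensor, video_tensor[:, :, N:]), dim=2)
#                   -> [1, C, N + M, H, W]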
def save_output(
    video_tensor: torch.Tensor,  # full decoded video [B, C, F, H, W], range [0, 1]
    args: argparse.Namespace,
    original_base_names: Optional[List[str]] = None,
    latent_to_save: Optional[torch.Tensor] = None,  # full latent [B, C, F, H, W]
    original_input_frames_np: Optional[List[np.ndarray]] = None,  # for Extend mode
) -> None:
    """save output video, images, or latent, handling concatenation for Extend mode"""
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    seed = args.seed

    is_extend_mode = original_input_frames_np is not None

    # --- Determine the final video tensor for saving ---
    video_to_save = video_tensor  # default: save the full decoded tensor
    final_video_length = video_tensor.shape[2]
    final_height = video_tensor.shape[3]
    final_width = video_tensor.shape[4]

    if is_extend_mode:
        logger.info("Processing output for Extend mode: concatenating original frames with generated frames.")
        num_original_frames = len(original_input_frames_np)

        # 1. Prepare the original frames tensor: list[HWC uint8] -> tensor [1, C, N, H, W] float32 [0, 1]
        original_frames_np_stacked = np.stack(original_input_frames_np, axis=0)  # [N, H, W, C]
        original_frames_tensor = torch.from_numpy(original_frames_np_stacked).permute(0, 3, 1, 2).float() / 255.0  # [N, C, H, W]
        original_frames_tensor = original_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # [1, C, N, H, W]
        original_frames_tensor = original_frames_tensor.to(video_tensor.device, dtype=video_tensor.dtype)  # match the decoded tensor

        # 2. Extract the generated part from the decoded tensor.
        # The decoded tensor contains the reconstructed input frames followed by the generated frames;
        # we only want the part *after* the input frames.
        if video_tensor.shape[2] <= num_original_frames:
            logger.error(
                f"Decoded video length ({video_tensor.shape[2]}) is not longer than the original frames "
                f"({num_original_frames}). Cannot extract the generated part."
            )
            # Fall back to saving the full decoded video for inspection
            logger.warning("Saving the full decoded video instead of concatenating.")
        else:
            generated_part_tensor = video_tensor[:, :, num_original_frames:, :, :]  # [B, C, M, H, W]
            # 3. Concatenate original pixel frames + generated pixel frames along the frame dimension
            video_to_save = torch.cat((original_frames_tensor, generated_part_tensor), dim=2)
            final_video_length = video_to_save.shape[2]  # update final length
            logger.info(
                f"Concatenated original {num_original_frames} frames with generated "
                f"{generated_part_tensor.shape[2]} frames. Final shape: {video_to_save.shape}"
            )

    # --- Determine the base filename ---
    base_name = f"{time_flag}_{seed}"
    if original_base_names:
        base_name += f"_{original_base_names[0]}"  # use the original name if loaded from a latent
    elif args.extend_video:
        input_video_name = os.path.splitext(os.path.basename(args.extend_video))[0]
        base_name += f"_ext_{input_video_name}"
    elif args.image_path:
        input_image_name = os.path.splitext(os.path.basename(args.image_path))[0]
        base_name += f"_i2v_{input_image_name}"
    elif args.video_path:
        input_video_name = os.path.splitext(os.path.basename(args.video_path))[0]
        base_name += f"_v2v_{input_video_name}"
    # A prompt hint could be appended here, but it tends to make filenames too long.

    # --- Save Latent ---
    if (args.output_type == "latent" or args.output_type == "both") and latent_to_save is not None:
        latent_path = os.path.join(save_path, f"{base_name}_latent.safetensors")
        logger.info(f"Saving latent tensor shape: {latent_to_save.shape}")  # save the full latent

        metadata = {}
        if not args.no_metadata:
            # Metadata reflects the dimensions of the final saved video
            metadata = {
                "prompt": f"{args.prompt}",
                "negative_prompt": f"{args.negative_prompt or ''}",
                "seeds": f"{seed}",
                "height": f"{final_height}",
                "width": f"{final_width}",
                "video_length": f"{final_video_length}",  # length of the *saved* video/latent
                "infer_steps": f"{args.infer_steps}",
                "guidance_scale": f"{args.guidance_scale}",
                "flow_shift": f"{args.flow_shift}",
                "task": f"{args.task}",
                "dit_model": f"{args.dit or (os.path.join(args.ckpt_dir, WAN_CONFIGS[args.task].dit_checkpoint) if args.ckpt_dir else 'N/A')}",
                "vae_model": f"{args.vae or (os.path.join(args.ckpt_dir, WAN_CONFIGS[args.task].vae_checkpoint) if args.ckpt_dir else 'N/A')}",
                "mode": ("Extend" if is_extend_mode else "I2V" if args.image_path else "V2V" if args.video_path else "T2V"),
            }
            if is_extend_mode:
                metadata["extend_video"] = f"{os.path.basename(args.extend_video)}"
                metadata["num_input_frames"] = f"{args.num_input_frames}"
                metadata["extend_length"] = f"{args.extend_length}"  # generated part length
                metadata["total_processed_length"] = f"{latent_to_save.shape[2]}"  # latent frame count
            # Other mode details (V2V strength, I2V image, etc.)
            if args.video_path:
                metadata["v2v_strength"] = f"{args.strength}"
            if args.image_path:
                metadata["i2v_image"] = f"{os.path.basename(args.image_path)}"
            if args.end_image_path:
                metadata["end_image"] = f"{os.path.basename(args.end_image_path)}"
            if args.control_path:
                metadata["funcontrol_video"] = f"{os.path.basename(args.control_path)}"
            if args.lora_weight:
                metadata["lora_weights"] = ", ".join([os.path.basename(p) for p in args.lora_weight])
                metadata["lora_multipliers"] = ", ".join(map(str, args.lora_multiplier))

        sd = {"latent": latent_to_save.cpu()}
        try:
            save_file(sd, latent_path, metadata=metadata)
            logger.info(f"Latent saved to: {latent_path}")
        except Exception as e:
            logger.error(f"Failed to save latent file: {e}")

    # --- Save Video or Images ---
    if args.output_type == "video" or args.output_type == "both":
        video_path = os.path.join(save_path, f"{base_name}.mp4")
        # save_videos_grid expects [B, T, H, W, C]; input is [B, C, T, H, W] in range [0, 1]
        try:
            # Ensure the tensor is on CPU for the saving function
            save_videos_grid(video_to_save.cpu(), video_path, fps=args.fps, rescale=False)
            logger.info(f"Video saved to: {video_path}")
        except Exception as e:
            logger.error(f"Failed to save video file: {e}")
            logger.error(
                f"Video tensor info: shape={video_to_save.shape}, dtype={video_to_save.dtype}, "
                f"min={video_to_save.min()}, max={video_to_save.max()}"
            )

    elif args.output_type == "images":
        image_save_dir = os.path.join(save_path, base_name)
        os.makedirs(image_save_dir, exist_ok=True)
        # save_images_grid expects [B, T, H, W, C]
        try:
            save_images_grid(video_to_save.cpu(), image_save_dir, "frame", rescale=False, save_individually=True)
            logger.info(f"Image frames saved to directory: {image_save_dir}")
        except Exception as e:
            logger.error(f"Failed to save image files: {e}")
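
# Example invocations (illustrative only; the script name, paths, and values below are
# placeholders, and additional flags such as --t5 or --ckpt_dir may be required by your setup):
#
#   # Extend an existing clip, conditioning on its first 8 frames:
#   python wan_generate_video.py --task <task> --dit <dit.safetensors> --vae <vae.safetensors> \
#       --extend_video input.mp4 --num_input_frames 8 \
#       --prompt "the camera keeps panning right" --save_path outputs --output_type both
#
#   # Video-to-video at 75% strength:
#   python wan_generate_video.py --task <task> --dit <dit.safetensors> --vae <vae.safetensors> \
#       --video_path input.mp4 --strength 0.75 --prompt "watercolor style" --save_path outputs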
def main():
    # --- Argument Parsing & Setup ---
    args = parse_args()
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0

    device_str = args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
    args.device = torch.device(device_str)
    logger.info(f"Using device: {args.device}")

    generated_latent = None
    original_input_frames_np = None  # original frames for Extend mode
    cfg = WAN_CONFIGS[args.task]
    height, width, video_length = None, None, None
    original_base_names = None  # for naming the output when loading latents

    if not latents_mode:
        # --- Generation Mode ---
        logger.info("Running in Generation Mode")
        args = setup_args(args)  # sets defaults, calculates video_length for Extend mode
        height, width, video_length = check_inputs(args)  # validate final dimensions
        args.video_size = [height, width]
        args.video_length = video_length  # ensure video_length is stored in args for processing

        mode_str = (
            "Extend" if args.extend_video
            else "I2V" if args.image_path
            else "V2V" if args.video_path
            else "T2V"
        )
        if args.control_path:
            mode_str += "+FunControl"
        logger.info(f"Mode: {mode_str}")

        logger.info(
            f"Settings: video size: {height}x{width}, processed length: {video_length} frames, fps: {args.fps}, "
            f"infer_steps: {args.infer_steps}, guidance: {args.guidance_scale}, flow_shift: {args.flow_shift}"
        )
        if args.extend_video:
            logger.info(
                f" Extend details: Input video: {args.extend_video}, Input frames: {args.num_input_frames}, "
                f"Generated frames: {args.extend_length}"
            )

        # Core generation pipeline - returns the latent and, in Extend mode, the original frames
        generated_latent, original_input_frames_np = generate(args)

        if args.save_merged_model:
            logger.info("Exiting after saving merged model.")
            return

        if generated_latent is None:
            logger.error("Generation failed or was skipped, exiting.")
            return

        # Get dimensions from the *generated latent* for logging/metadata consistency
        _, _, lat_f, lat_h, lat_w = generated_latent.shape
        processed_pixel_height = lat_h * cfg.vae_stride[1]
        processed_pixel_width = lat_w * cfg.vae_stride[2]
        processed_pixel_frames = (lat_f - 1) * cfg.vae_stride[0] + 1
        logger.info(
            f"Generation complete. Processed latent shape: {generated_latent.shape} -> "
            f"Approx Pixel Video: {processed_pixel_height}x{processed_pixel_width}@{processed_pixel_frames}"
        )
        # Note: final saved dimensions may differ slightly due to concatenation in Extend mode
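    # Latents Mode below re-decodes a previously saved latent: the safetensors metadata written
    # by save_output (prompt, seeds, height/width, video_length, guidance_scale, infer_steps,
    # flow_shift) is read back, so the original generation settings do not have to be repeated.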
    else:
        # --- Latents Mode ---
        logger.info("Running in Latent Loading Mode")
        original_base_names = []
        latents_list = []
        seeds = []
        metadata = {}

        if len(args.latent_path) > 1:
            logger.warning("Loading multiple latent files is not fully supported. Using first file's info.")

        latent_path = args.latent_path[0]
        original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
        loaded_latent = None
        seed = args.seed if args.seed is not None else 0

        try:
            if os.path.splitext(latent_path)[1] != ".safetensors":
                logger.warning("Loading non-safetensors latent file. Metadata might be missing.")
                loaded_latent = torch.load(latent_path, map_location="cpu")
                if isinstance(loaded_latent, dict):
                    if "latent" in loaded_latent:
                        loaded_latent = loaded_latent["latent"]
                    elif "state_dict" in loaded_latent:
                        raise ValueError("Loaded file appears to be a model checkpoint.")
                    else:
                        first_key = next(iter(loaded_latent))
                        loaded_latent = loaded_latent[first_key]
            else:
                loaded_latent = load_file(latent_path, device="cpu")["latent"]
                with safe_open(latent_path, framework="pt", device="cpu") as f:
                    metadata = f.metadata() or {}
                logger.info(f"Loaded metadata: {metadata}")

                # Restore args from metadata if available
                if "seeds" in metadata:
                    seed = int(metadata["seeds"])
                if "prompt" in metadata:
                    args.prompt = metadata["prompt"]
                if "negative_prompt" in metadata:
                    args.negative_prompt = metadata["negative_prompt"]
                # Use metadata dimensions if available, otherwise infer later
                if "height" in metadata and "width" in metadata:
                    height = int(metadata["height"])
                    width = int(metadata["width"])
                    args.video_size = [height, width]
                if "video_length" in metadata:
                    # This is the length of the *saved* video/latent
                    video_length = int(metadata["video_length"])
                    args.video_length = video_length
                # Restore other relevant args
                if "guidance_scale" in metadata:
                    args.guidance_scale = float(metadata["guidance_scale"])
                if "infer_steps" in metadata:
                    args.infer_steps = int(metadata["infer_steps"])
                if "flow_shift" in metadata:
                    args.flow_shift = float(metadata["flow_shift"])
                if "mode" in metadata and metadata["mode"] == "Extend":
                    if "num_input_frames" in metadata:
                        args.num_input_frames = int(metadata["num_input_frames"])
                    # The original frames cannot be recovered from the latent, so Extend-style
                    # concatenation is not possible when loading from a latent file.

            seeds.append(seed)
            latents_list.append(loaded_latent)
            logger.info(f"Loaded latent from {latent_path}. Shape: {loaded_latent.shape}, dtype: {loaded_latent.dtype}")
        except Exception as e:
            logger.error(f"Failed to load latent file {latent_path}: {e}")
            return

        if not latents_list:
            logger.error("No latent tensors loaded.")
            return

        generated_latent = torch.stack(latents_list, dim=0)  # [B, C, F, H, W]
        if len(generated_latent.shape) != 5:
            raise ValueError(f"Loaded latent shape error: {generated_latent.shape}")
        args.seed = seeds[0]

        # Infer pixel dimensions from the latent if not fully set by metadata
        if height is None or width is None or video_length is None:
            logger.warning("Dimensions not fully found in metadata, inferring from latent shape.")
            _, _, lat_f, lat_h, lat_w = generated_latent.shape
            height = lat_h * cfg.vae_stride[1]
            width = lat_w * cfg.vae_stride[2]
            video_length = (lat_f - 1) * cfg.vae_stride[0] + 1  # the length corresponding to the latent
            logger.info(f"Inferred pixel dimensions from latent: {height}x{width}@{video_length}")
        args.video_size = [height, width]
        args.video_length = video_length

    # --- Decode and Save ---
    if generated_latent is not None:
        # Decode the latent to a video tensor [B, C, F, H, W], range [0, 1]
        # Note: args.video_length may differ from the latent's frame dim if frames are trimmed during decode
        decoded_video = decode_latent(generated_latent, args, cfg)

        # Save output (Extend-mode concatenation is handled inside)
        save_output(
            decoded_video,
            args,
            original_base_names=original_base_names,
            latent_to_save=generated_latent if args.output_type in ["latent", "both"] else None,
            original_input_frames_np=original_input_frames_np,  # pass the original frames if in Extend mode
        )
    else:
        logger.error("No latent available for decoding and saving.")

    logger.info("Done!")


if __name__ == "__main__":
    main()