import argparse
from datetime import datetime
import gc
import random
import os
import re
import time
import math
import copy
from types import ModuleType, SimpleNamespace
from typing import Tuple, Optional, List, Union, Any, Dict

import torch
import accelerate
from accelerate import Accelerator
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms.functional as TF
from tqdm import tqdm

from networks import lora_wan
from utils.safetensors_utils import mem_eff_save_file, load_safetensors
from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
import wan
from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
from wan.modules.vae import WanVAE
from wan.modules.t5 import T5EncoderModel
from wan.modules.clip import CLIPModel
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # LyCORIS is optional; it is only required when --lycoris is specified
    pass

from utils.model_utils import str_to_dtype
from utils.device_utils import clean_memory_on_device
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
from dataset.image_video_dataset import load_video

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class GenerationSettings:
    def __init__(
        self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
    ):
        self.device = device
        self.cfg = cfg
        self.dit_dtype = dit_dtype
        self.dit_weight_dtype = dit_weight_dtype
        self.vae_dtype = vae_dtype


def parse_args() -> argparse.Namespace:
    """parse command line arguments"""
    parser = argparse.ArgumentParser(description="Wan 2.1 inference script")

    # WAN arguments
    parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
    parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
    parser.add_argument(
        "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
    )
    parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
    parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
    parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
    parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
    parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
    parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") | |
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") | |
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") | |
parser.add_argument( | |
"--save_merged_model", | |
type=str, | |
default=None, | |
help="Save merged model to path. If specified, no inference will be performed.", | |
) | |
# inference | |
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") | |
parser.add_argument( | |
"--negative_prompt", | |
type=str, | |
default=None, | |
help="negative prompt for generation, use default negative prompt if not specified", | |
) | |
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") | |
parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task") | |
parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") | |
parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") | |
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") | |
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") | |
parser.add_argument( | |
"--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." | |
) | |
parser.add_argument( | |
"--guidance_scale", | |
type=float, | |
default=5.0, | |
help="Guidance scale for classifier free guidance. Default is 5.0.", | |
) | |
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") | |
parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") | |
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference") | |
parser.add_argument( | |
"--control_path", | |
type=str, | |
default=None, | |
help="path to control video for inference with controlnet. video file or directory with images", | |
) | |
parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") | |
parser.add_argument( | |
"--cfg_skip_mode", | |
type=str, | |
default="none", | |
choices=["early", "late", "middle", "early_late", "alternate", "none"], | |
help="CFG skip mode. each mode skips different parts of the CFG. " | |
" early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", | |
) | |
parser.add_argument( | |
"--cfg_apply_ratio", | |
type=float, | |
default=None, | |
help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", | |
) | |
parser.add_argument( | |
"--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" | |
) | |
parser.add_argument( | |
"--slg_scale", | |
type=float, | |
default=3.0, | |
help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond", | |
) | |
parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") | |
parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") | |
parser.add_argument( | |
"--slg_mode", | |
type=str, | |
default=None, | |
choices=["original", "uncond"], | |
help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", | |
) | |
# Flow Matching | |
parser.add_argument( | |
"--flow_shift", | |
type=float, | |
default=None, | |
help="Shift factor for flow matching schedulers. Default depends on task.", | |
) | |
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") | |
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") | |
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") | |
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") | |
parser.add_argument( | |
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" | |
) | |
parser.add_argument( | |
"--attn_mode", | |
type=str, | |
default="torch", | |
choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], | |
help="attention mode", | |
) | |
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") | |
parser.add_argument( | |
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" | |
) | |
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") | |
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") | |
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") | |
parser.add_argument("--compile", action="store_true", help="Enable torch.compile") | |
parser.add_argument( | |
"--compile_args", | |
nargs=4, | |
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), | |
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], | |
help="Torch.compile settings", | |
) | |
# New arguments for batch and interactive modes | |
parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") | |
parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") | |
args = parser.parse_args() | |
# Validate arguments | |
if args.from_file and args.interactive: | |
raise ValueError("Cannot use both --from_file and --interactive at the same time") | |
if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None: | |
raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified") | |
assert (args.latent_path is None or len(args.latent_path) == 0) or ( | |
args.output_type == "images" or args.output_type == "video" | |
), "latent_path is only supported for images or video output" | |
return args | |
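

# A minimal invocation sketch (hypothetical paths; the script file name is assumed
# here), covering the arguments parsed above:
#   python wan_generate_video.py --task t2v-14B --dit /path/to/dit.safetensors \
#       --vae /path/to/vae.safetensors --t5 /path/to/t5.safetensors \
#       --prompt "a cat walking on the grass" --video_size 480 832 --video_length 81 \
#       --infer_steps 50 --save_path ./outputs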


def parse_prompt_line(line: str) -> Dict[str, Any]:
    """Parse a prompt line into a dictionary of argument overrides

    Args:
        line: Prompt line with options

    Returns:
        Dict[str, Any]: Dictionary of argument overrides
    """
    # TODO common function with hv_train_network.line_to_prompt_dict
    parts = line.split(" --")
    prompt = parts[0].strip()

    # Create dictionary of overrides
    overrides = {"prompt": prompt}

    for part in parts[1:]:
        if not part.strip():
            continue
        option_parts = part.split(" ", 1)
        option = option_parts[0].strip()
        value = option_parts[1].strip() if len(option_parts) > 1 else ""

        # Map options to argument names
        if option == "w":
            overrides["video_size_width"] = int(value)
        elif option == "h":
            overrides["video_size_height"] = int(value)
        elif option == "f":
            overrides["video_length"] = int(value)
        elif option == "d":
            overrides["seed"] = int(value)
        elif option == "s":
            overrides["infer_steps"] = int(value)
        elif option == "g" or option == "l":
            overrides["guidance_scale"] = float(value)
        elif option == "fs":
            overrides["flow_shift"] = float(value)
        elif option == "i":
            overrides["image_path"] = value
        elif option == "cn":
            overrides["control_path"] = value
        elif option == "n":
            overrides["negative_prompt"] = value

    return overrides
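

# Example prompt line for --from_file / --interactive (an illustration built from the
# options parsed above):
#   a cat walking on the grass --w 832 --h 480 --f 81 --d 42 --s 20 --g 5.0 --n blurry
# parses to:
#   {"prompt": "a cat walking on the grass", "video_size_width": 832,
#    "video_size_height": 480, "video_length": 81, "seed": 42, "infer_steps": 20,
#    "guidance_scale": 5.0, "negative_prompt": "blurry"}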


def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
    """Apply overrides to args

    Args:
        args: Original arguments
        overrides: Dictionary of overrides

    Returns:
        argparse.Namespace: New arguments with overrides applied
    """
    args_copy = copy.deepcopy(args)

    for key, value in overrides.items():
        if key == "video_size_width":
            args_copy.video_size[1] = value
        elif key == "video_size_height":
            args_copy.video_size[0] = value
        else:
            setattr(args_copy, key, value)

    return args_copy
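

# Typical use: prompt_args = apply_overrides(args, parse_prompt_line(line)) yields a
# per-prompt copy of the base arguments; deepcopy keeps the original args untouched.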


def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
    """Return default values for each task

    Args:
        task: task name (t2v, t2i, i2v etc.)
        size: size of the video (height, width), as passed from args.video_size

    Returns:
        Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
    """
    # callers pass (height, width); the 480p check below is symmetric in both axes
    height, width = size if size else (0, 0)

    if "t2i" in task:
        return 50, 5.0, 1, False
    elif "i2v" in task:
        flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
        return 40, flow_shift, 81, True
    else:  # t2v or default
        return 50, 5.0, 81, False


def setup_args(args: argparse.Namespace) -> argparse.Namespace:
    """Validate and set default values for optional arguments

    Args:
        args: command line arguments

    Returns:
        argparse.Namespace: updated arguments
    """
    # Get default values for the task
    infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))

    # Apply default values to unset arguments
    if args.infer_steps is None:
        args.infer_steps = infer_steps
    if args.flow_shift is None:
        args.flow_shift = flow_shift
    if args.video_length is None:
        args.video_length = video_length

    # Force video_length to 1 for t2i tasks
    if "t2i" in args.task:
        assert args.video_length == 1, f"video_length should be 1 for task {args.task}"

    # parse slg_layers
    if args.slg_layers is not None:
        args.slg_layers = list(map(int, args.slg_layers.split(",")))

    return args


def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
    """Validate video size and length

    Args:
        args: command line arguments

    Returns:
        Tuple[int, int, int]: (height, width, video_length)
    """
    height = args.video_size[0]
    width = args.video_size[1]

    size = f"{width}*{height}"
    if size not in SUPPORTED_SIZES[args.task]:
        logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")

    video_length = args.video_length

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    return height, width, video_length


def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
    """calculate dimensions for the generation

    Args:
        video_size: video frame size (height, width)
        video_length: number of frames in the video
        config: model configuration

    Returns:
        Tuple[Tuple[int, int, int, int], int]:
            ((channels, frames, height, width), seq_len)
    """
    height, width = video_size
    frames = video_length

    # calculate latent space dimensions
    lat_f = (frames - 1) // config.vae_stride[0] + 1
    lat_h = height // config.vae_stride[1]
    lat_w = width // config.vae_stride[2]

    # calculate sequence length
    seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)

    return ((16, lat_f, lat_h, lat_w), seq_len)
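

# Worked example (assuming the usual Wan 2.1 values vae_stride=(4, 8, 8) and
# patch_size=(1, 2, 2)): a 480x832 video with 81 frames gives
#   lat_f = (81 - 1) // 4 + 1 = 21, lat_h = 480 // 8 = 60, lat_w = 832 // 8 = 104,
#   seq_len = ceil(60 * 104 / (2 * 2) * 21) = 32760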


def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
    """load VAE model

    Args:
        args: command line arguments
        config: model configuration
        device: device to use
        dtype: data type for the model

    Returns:
        WanVAE: loaded VAE model
    """
    vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)

    logger.info(f"Loading VAE model from {vae_path}")
    cache_device = torch.device("cpu") if args.vae_cache_cpu else None
    vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
    return vae


def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
    """load text encoder (T5) model

    Args:
        args: command line arguments
        config: model configuration
        device: device to use

    Returns:
        T5EncoderModel: loaded text encoder model
    """
    checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
    tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)

    text_encoder = T5EncoderModel(
        text_len=config.text_len,
        dtype=config.t5_dtype,
        device=device,
        checkpoint_path=checkpoint_path,
        tokenizer_path=tokenizer_path,
        weight_path=args.t5,
        fp8=args.fp8_t5,
    )

    return text_encoder


def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
    """load CLIP model (for I2V only)

    Args:
        args: command line arguments
        config: model configuration
        device: device to use

    Returns:
        CLIPModel: loaded CLIP model
    """
    checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
    tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)

    clip = CLIPModel(
        dtype=config.clip_dtype,
        device=device,
        checkpoint_path=checkpoint_path,
        tokenizer_path=tokenizer_path,
        weight_path=args.clip,
    )

    return clip


def load_dit_model(
    args: argparse.Namespace,
    config,
    device: torch.device,
    dit_dtype: torch.dtype,
    dit_weight_dtype: Optional[torch.dtype] = None,
    is_i2v: bool = False,
) -> WanModel:
    """load DiT model

    Args:
        args: command line arguments
        config: model configuration
        device: device to use
        dit_dtype: data type for the model
        dit_weight_dtype: data type for the model weights. None for as-is
        is_i2v: I2V mode

    Returns:
        WanModel: loaded DiT model
    """
    loading_device = "cpu"
    if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
        loading_device = device

    loading_weight_dtype = dit_weight_dtype
    if args.fp8_scaled or args.lora_weight is not None:
        loading_weight_dtype = dit_dtype  # load as-is

    # do not fp8 optimize because we will merge LoRA weights
    model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)

    return model


def merge_lora_weights(lora_module: ModuleType, model: torch.nn.Module, args: argparse.Namespace, device: torch.device) -> None:
    """merge LoRA weights to the model

    Args:
        lora_module: LoRA network module to use (e.g. lora_wan)
        model: DiT model
        args: command line arguments
        device: device to use
    """
    if args.lora_weight is None or len(args.lora_weight) == 0:
        return

    for i, lora_weight in enumerate(args.lora_weight):
        if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
            lora_multiplier = args.lora_multiplier[i]
        else:
            lora_multiplier = 1.0

        logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
        weights_sd = load_file(lora_weight)

        # apply include/exclude patterns
        original_key_count = len(weights_sd.keys())
        if args.include_patterns is not None and len(args.include_patterns) > i:
            include_pattern = args.include_patterns[i]
            regex_include = re.compile(include_pattern)
            weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
            logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
        if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
            original_key_count_ex = len(weights_sd.keys())
            exclude_pattern = args.exclude_patterns[i]
            regex_exclude = re.compile(exclude_pattern)
            weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
            logger.info(
                f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
            )
        if len(weights_sd) != original_key_count:
            remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
            remaining_keys.sort()
            logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
            if len(weights_sd) == 0:
                logger.warning("No keys left after filtering.")

        if args.lycoris:
            lycoris_net, _ = create_network_from_weights(
                multiplier=lora_multiplier,
                file=None,
                weights_sd=weights_sd,
                unet=model,
                text_encoder=None,
                vae=None,
                for_inference=True,
            )
            lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
        else:
            network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
            network.merge_to(None, model, weights_sd, device=device, non_blocking=True)

        synchronize_device(device)
        logger.info("LoRA weights loaded")

    # save model here before casting to dit_weight_dtype
    if args.save_merged_model:
        logger.info(f"Saving merged model to {args.save_merged_model}")
        mem_eff_save_file(model.state_dict(), args.save_merged_model)  # save_file needs a lot of memory
        logger.info("Merged model saved")


def optimize_model(
    model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
) -> None:
    """optimize the model (FP8 conversion, device move etc.)

    Args:
        model: dit model
        args: command line arguments
        device: device to use
        dit_dtype: dtype for the model
        dit_weight_dtype: dtype for the model weights
    """
    if args.fp8_scaled:
        # load state dict as-is and optimize to fp8
        state_dict = model.state_dict()

        # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
        move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, we will keep the model on CPU
        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)

        info = model.load_state_dict(state_dict, strict=True, assign=True)
        logger.info(f"Loaded FP8 optimized weights: {info}")

        if args.blocks_to_swap == 0:
            model.to(device)  # make sure all parameters are on the right device (e.g. RoPE etc.)
    else:
        # simple cast to dit_dtype
        target_dtype = None  # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
        target_device = None

        if dit_weight_dtype is not None:  # in case of args.fp8 and not args.fp8_scaled
            logger.info(f"Convert model to {dit_weight_dtype}")
            target_dtype = dit_weight_dtype

        if args.blocks_to_swap == 0:
            logger.info(f"Move model to device: {device}")
            target_device = device

        model.to(target_device, target_dtype)  # move and cast at the same time. this reduces redundant copy operations

    if args.compile:
        compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
        logger.info(
            f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
        )
        torch._dynamo.config.cache_size_limit = 32
        for i in range(len(model.blocks)):
            model.blocks[i] = torch.compile(
                model.blocks[i],
                backend=compile_backend,
                mode=compile_mode,
                dynamic=compile_dynamic.lower() == "true",  # string equality, not substring containment
                fullgraph=compile_fullgraph.lower() == "true",
            )

    if args.blocks_to_swap > 0:
        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()
    else:
        # make sure the model is on the right device
        model.to(device)

    model.eval().requires_grad_(False)
    clean_memory_on_device(device)


def prepare_t2v_inputs(
    args: argparse.Namespace,
    config,
    accelerator: Accelerator,
    device: torch.device,
    vae: Optional[WanVAE] = None,
    encoded_context: Optional[Dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
    """Prepare inputs for T2V

    Args:
        args: command line arguments
        config: model configuration
        accelerator: Accelerator instance
        device: device to use
        vae: VAE model for control video encoding
        encoded_context: Pre-encoded text context

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
            (noise, context, context_null, (arg_c, arg_null))
    """
    # calculate dimensions and sequence length
    height, width = args.video_size
    frames = args.video_length
    (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
    target_shape = (16, lat_f, lat_h, lat_w)

    # configure negative prompt
    n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt

    # set seed
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    if not args.cpu_noise:
        seed_g = torch.Generator(device=device)
        seed_g.manual_seed(seed)
    else:
        # ComfyUI compatible noise
        seed_g = torch.manual_seed(seed)

    if encoded_context is None:
        # load text encoder
        text_encoder = load_text_encoder(args, config, device)
        text_encoder.model.to(device)

        # encode prompt
        with torch.no_grad():
            if args.fp8_t5:
                with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
                    context = text_encoder([args.prompt], device)
                    context_null = text_encoder([n_prompt], device)
            else:
                context = text_encoder([args.prompt], device)
                context_null = text_encoder([n_prompt], device)

        # free text encoder and clean memory
        del text_encoder
        clean_memory_on_device(device)
    else:
        # Use pre-encoded context
        context = encoded_context["context"]
        context_null = encoded_context["context_null"]

    # Fun-Control: encode control video to latent space
    if config.is_fun_control:
        # TODO use same resizing as for image
        logger.info("Encoding control video to latent space")
        # C, F, H, W
        control_video = load_control_video(args.control_path, frames, height, width).to(device)
        vae.to_device(device)
        with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
            control_latent = vae.encode([control_video])[0]
        y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0)  # add control video latent
        vae.to_device("cpu")
    else:
        y = None

    # generate noise
    noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
    noise = noise.to(device)

    # prepare model input arguments
    arg_c = {"context": context, "seq_len": seq_len}
    arg_null = {"context": context_null, "seq_len": seq_len}
    if y is not None:
        arg_c["y"] = [y]
        arg_null["y"] = [y]

    return noise, context, context_null, (arg_c, arg_null)


def prepare_i2v_inputs(
    args: argparse.Namespace,
    config,
    accelerator: Accelerator,
    device: torch.device,
    vae: WanVAE,
    encoded_context: Optional[Dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
    """Prepare inputs for I2V

    Args:
        args: command line arguments
        config: model configuration
        accelerator: Accelerator instance
        device: device to use
        vae: VAE model, used for image encoding
        encoded_context: Pre-encoded text context

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
            (noise, context, context_null, y, (arg_c, arg_null))
    """
    # get video dimensions
    height, width = args.video_size
    frames = args.video_length
    max_area = width * height

    # load image
    img = Image.open(args.image_path).convert("RGB")

    # convert to numpy
    img_cv2 = np.array(img)  # PIL to numpy

    # convert to tensor (-1 to 1)
    img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)

    # end frame image
    if args.end_image_path is not None:
        end_img = Image.open(args.end_image_path).convert("RGB")
        end_img_cv2 = np.array(end_img)  # PIL to numpy
    else:
        end_img = None
        end_img_cv2 = None
    has_end_image = end_img is not None

    # calculate latent dimensions: keep aspect ratio
    height, width = img_tensor.shape[1:]
    aspect_ratio = height / width
    lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
    lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
    height = lat_h * config.vae_stride[1]
    width = lat_w * config.vae_stride[2]
    lat_f = (frames - 1) // config.vae_stride[0] + 1  # size of latent frames
    max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])

    # set seed
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    if not args.cpu_noise:
        seed_g = torch.Generator(device=device)
        seed_g.manual_seed(seed)
    else:
        # ComfyUI compatible noise
        seed_g = torch.manual_seed(seed)

    # generate noise
    noise = torch.randn(
        16,
        lat_f + (1 if has_end_image else 0),
        lat_h,
        lat_w,
        dtype=torch.float32,
        generator=seed_g,
        device=device if not args.cpu_noise else "cpu",
    )
    noise = noise.to(device)

    # configure negative prompt
    n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt

    if encoded_context is None:
        # load text encoder
        text_encoder = load_text_encoder(args, config, device)
        text_encoder.model.to(device)

        # encode prompt
        with torch.no_grad():
            if args.fp8_t5:
                with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
                    context = text_encoder([args.prompt], device)
                    context_null = text_encoder([n_prompt], device)
            else:
                context = text_encoder([args.prompt], device)
                context_null = text_encoder([n_prompt], device)

        # free text encoder and clean memory
        del text_encoder
        clean_memory_on_device(device)

        # load CLIP model
        clip = load_clip_model(args, config, device)
        clip.model.to(device)

        # encode image to CLIP context
        logger.info("Encoding image to CLIP context")
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
            clip_context = clip.visual([img_tensor[:, None, :, :]])
        logger.info("Encoding complete")

        # free CLIP model and clean memory
        del clip
        clean_memory_on_device(device)
    else:
        # Use pre-encoded context
        context = encoded_context["context"]
        context_null = encoded_context["context_null"]
        clip_context = encoded_context["clip_context"]

    # encode image to latent space with VAE
    logger.info("Encoding image to latent space")
    vae.to_device(device)

    # resize image
    interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
    img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
    img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device)  # -1 to 1, CHW
    img_resized = img_resized.unsqueeze(1)  # CFHW

    if has_end_image:
        interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[0] else cv2.INTER_CUBIC  # shape[0] is the source height, matching the first-frame path
        end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
        end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device)  # -1 to 1, CHW
        end_img_resized = end_img_resized.unsqueeze(1)  # CFHW

    # create mask for the first frame
    msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
    msk[:, 0] = 1
    if has_end_image:
        msk[:, -1] = 1

    # encode image to latent space
    with accelerator.autocast(), torch.no_grad():
        # padding to match the required number of frames
        padding_frames = frames - 1  # the first frame is image
        img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
        y = vae.encode([img_resized])[0]

        if has_end_image:
            y_end = vae.encode([end_img_resized])[0]
            y = torch.concat([y, y_end], dim=1)  # add end frame

    y = torch.concat([msk, y])
    logger.info("Encoding complete")

    # Fun-Control: encode control video to latent space
    if config.is_fun_control:
        # TODO use same resizing as for image
        logger.info("Encoding control video to latent space")
        # C, F, H, W
        control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
        with accelerator.autocast(), torch.no_grad():
            control_latent = vae.encode([control_video])[0]
        y = y[msk.shape[0]:]  # remove mask because Fun-Control does not need it
        if has_end_image:
            y[:, 1:-1] = 0  # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work
        else:
            y[:, 1:] = 0  # remove image latent except first frame
        y = torch.concat([control_latent, y], dim=0)  # add control video latent

    # prepare model input arguments
    arg_c = {
        "context": [context[0]],
        "clip_fea": clip_context,
        "seq_len": max_seq_len,
        "y": [y],
    }
    arg_null = {
        "context": context_null,
        "clip_fea": clip_context,
        "seq_len": max_seq_len,
        "y": [y],
    }

    vae.to_device("cpu")  # move VAE to CPU to save memory
    clean_memory_on_device(device)

    return noise, context, context_null, y, (arg_c, arg_null)


def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
    """load control video (pixel values; VAE encoding happens in the caller)

    Args:
        control_path: path to control video
        frames: number of frames in the video
        height: height of the video
        width: width of the video

    Returns:
        torch.Tensor: control video tensor, CFHW, range [-1, 1]
    """
    logger.info(f"Load control video from {control_path}")
    video = load_video(control_path, 0, frames, bucket_reso=(width, height))  # list of frames
    if len(video) < frames:
        raise ValueError(f"Video length is less than {frames}")
    # video = np.stack(video, axis=0)  # F, H, W, C
    video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0)  # F, C, H, W, -1 to 1
    video = video.permute(1, 0, 2, 3)  # C, F, H, W
    return video


def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
    """setup scheduler for sampling

    Args:
        args: command line arguments
        config: model configuration
        device: device to use

    Returns:
        Tuple[Any, torch.Tensor]: (scheduler, timesteps)
    """
    if args.sample_solver == "unipc":
        scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
        scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
        timesteps = scheduler.timesteps
    elif args.sample_solver == "dpm++":
        scheduler = FlowDPMSolverMultistepScheduler(
            num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
        )
        sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
        timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
    elif args.sample_solver == "vanilla":
        scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
        scheduler.set_timesteps(args.infer_steps, device=device)
        timesteps = scheduler.timesteps

        # FlowMatchDiscreteScheduler does not support generator argument in step method
        org_step = scheduler.step

        def step_wrapper(
            model_output: torch.Tensor,
            timestep: Union[int, torch.Tensor],
            sample: torch.Tensor,
            return_dict: bool = True,
            generator=None,
        ):
            return org_step(model_output, timestep, sample, return_dict=return_dict)

        scheduler.step = step_wrapper
    else:
        raise NotImplementedError("Unsupported solver.")

    return scheduler, timesteps


def run_sampling(
    model: WanModel,
    noise: torch.Tensor,
    scheduler: Any,
    timesteps: torch.Tensor,
    args: argparse.Namespace,
    inputs: Tuple[dict, dict],
    device: torch.device,
    seed_g: torch.Generator,
    accelerator: Accelerator,
    is_i2v: bool = False,
    use_cpu_offload: bool = True,
) -> torch.Tensor:
    """run sampling

    Args:
        model: dit model
        noise: initial noise
        scheduler: scheduler for sampling
        timesteps: time steps for sampling
        args: command line arguments
        inputs: model input (arg_c, arg_null)
        device: device to use
        seed_g: random generator
        accelerator: Accelerator instance
        is_i2v: I2V mode (False means T2V mode)
        use_cpu_offload: Whether to offload tensors to CPU during processing

    Returns:
        torch.Tensor: generated latent
    """
    arg_c, arg_null = inputs

    latent = noise
    latent_storage_device = device if not use_cpu_offload else "cpu"
    latent = latent.to(latent_storage_device)

    # cfg skip
    apply_cfg_array = []
    num_timesteps = len(timesteps)

    if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
        # Calculate thresholds based on cfg_apply_ratio
        apply_steps = int(num_timesteps * args.cfg_apply_ratio)

        if args.cfg_skip_mode == "early":
            # Skip CFG in early steps, apply in late steps
            start_index = num_timesteps - apply_steps
            end_index = num_timesteps
        elif args.cfg_skip_mode == "late":
            # Skip CFG in late steps, apply in early steps
            start_index = 0
            end_index = apply_steps
        elif args.cfg_skip_mode == "early_late":
            # Skip CFG in early and late steps, apply in middle steps
            start_index = (num_timesteps - apply_steps) // 2
            end_index = start_index + apply_steps
        elif args.cfg_skip_mode == "middle":
            # Skip CFG in middle steps, apply in early and late steps
            skip_steps = num_timesteps - apply_steps
            middle_start = (num_timesteps - skip_steps) // 2
            middle_end = middle_start + skip_steps

        w = 0.0
        for step_idx in range(num_timesteps):
            if args.cfg_skip_mode == "alternate":
                # accumulate w and apply CFG when w >= 1.0
                w += args.cfg_apply_ratio
                apply = w >= 1.0
                if apply:
                    w -= 1.0
            elif args.cfg_skip_mode == "middle":
                # Skip CFG in middle steps, apply in early and late steps
                apply = step_idx < middle_start or step_idx >= middle_end
            else:
                # Apply CFG on some steps based on ratio
                apply = step_idx >= start_index and step_idx < end_index

            apply_cfg_array.append(apply)

        pattern = ["A" if apply else "S" for apply in apply_cfg_array]
        pattern = "".join(pattern)
        logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
    else:
        # Apply CFG on all steps
        apply_cfg_array = [True] * num_timesteps
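
    # e.g. with 10 steps and cfg_apply_ratio=0.5 (a hand-checked illustration):
    #   mode "early" -> pattern "SSSSSAAAAA" (skip the first five steps, apply the last five)
    #   mode "late"  -> pattern "AAAAASSSSS"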

    # SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
    slg_start_step = int(args.slg_start * num_timesteps)
    slg_end_step = int(args.slg_end * num_timesteps)
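
    # e.g. with the defaults slg_start=0.0 and slg_end=0.3 over 50 steps, SLG can fire
    # on steps 0..14 (slg_start_step=0, slg_end_step=15); it still requires
    # --slg_mode and --slg_layers to be set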

    for i, t in enumerate(tqdm(timesteps)):
        # latent is on CPU if use_cpu_offload is True
        latent_model_input = [latent.to(device)]
        timestep = torch.stack([t]).to(device)

        with accelerator.autocast(), torch.no_grad():
            noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)

            apply_cfg = apply_cfg_array[i]  # apply CFG or not
            if apply_cfg:
                apply_slg = i >= slg_start_step and i < slg_end_step
                # print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
                if args.slg_mode == "original" and apply_slg:
                    noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)

                    # apply guidance
                    # SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # calculate skip layer out
                    skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
                        latent_storage_device
                    )

                    # apply skip layer guidance
                    # SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
                    noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
                elif args.slg_mode == "uncond" and apply_slg:
                    # noise_pred_uncond is skip layer out
                    noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
                        latent_storage_device
                    )

                    # apply guidance
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    # normal guidance
                    noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)

                    # apply guidance
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = noise_pred_cond

            # step
            latent_input = latent.unsqueeze(0)
            temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]

            # update latent
            latent = temp_x0.squeeze(0)

    return latent


def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
    """main function for generation

    Args:
        args: command line arguments
        gen_settings: generation settings (device, config and dtypes)
        shared_models: dictionary containing pre-loaded models and encoded data

    Returns:
        torch.Tensor: generated latent
    """
    device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
        gen_settings.device,
        gen_settings.cfg,
        gen_settings.dit_dtype,
        gen_settings.dit_weight_dtype,
        gen_settings.vae_dtype,
    )

    # prepare accelerator
    mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
    accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)

    # I2V or T2V
    is_i2v = "i2v" in args.task

    # prepare seed
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    args.seed = seed  # set seed to args for saving

    # Check if we have shared models
    if shared_models is not None:
        # Use shared models and encoded data
        vae = shared_models.get("vae")
        model = shared_models.get("model")
        encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)

        # prepare inputs
        if is_i2v:
            # I2V
            noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
        else:
            # T2V
            noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
    else:
        # prepare inputs without shared models
        if is_i2v:
            # I2V: need text encoder, VAE and CLIP
            vae = load_vae(args, cfg, device, vae_dtype)
            noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
            # vae is on CPU after prepare_i2v_inputs
        else:
            # T2V: need text encoder
            vae = None
            if cfg.is_fun_control:
                # Fun-Control: need VAE for encoding control video
                vae = load_vae(args, cfg, device, vae_dtype)
            noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)

        # load DiT model
        model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)

        # merge LoRA weights
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            merge_lora_weights(lora_wan, model, args, device)

            # if we only want to save the model, we can skip the rest
            if args.save_merged_model:
                return None

        # optimize model: fp8 conversion, block swap etc.
        optimize_model(model, args, device, dit_dtype, dit_weight_dtype)

    # setup scheduler
    scheduler, timesteps = setup_scheduler(args, cfg, device)

    # set random generator
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)

    # run sampling
    latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)

    # Only clean up shared models if they were created within this function
    if shared_models is None:
        # free memory
        del model
        del scheduler
        synchronize_device(device)

        # wait for 5 seconds until block swap is done
        logger.info("Waiting for 5 seconds to finish block swap")
        time.sleep(5)

        gc.collect()
        clean_memory_on_device(device)

    # keep the VAE instance (or None) on args so decode_latent can reuse it
    args._vae = vae

    return latent


def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
    """decode latent

    Args:
        latent: latent tensor
        args: command line arguments
        cfg: model configuration

    Returns:
        torch.Tensor: decoded video or image
    """
    device = torch.device(args.device)

    # load VAE model or use the one from the generation
    vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
    if hasattr(args, "_vae") and args._vae is not None:
        vae = args._vae
    else:
        vae = load_vae(args, cfg, device, vae_dtype)

    vae.to_device(device)

    logger.info(f"Decoding video from latents: {latent.shape}")
    x0 = latent.to(device)

    with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
        videos = vae.decode(x0)

    # some tail frames may be corrupted when end frame is used, we add an option to remove them
    if args.trim_tail_frames:
        videos[0] = videos[0][:, : -args.trim_tail_frames]

    logger.info("Decoding complete")
    video = videos[0]
    del videos
    video = video.to(torch.float32).cpu()
    return video


def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
    """Save latent to file

    Args:
        latent: latent tensor
        args: command line arguments
        height: height of frame
        width: width of frame

    Returns:
        str: Path to saved latent file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    video_length = args.video_length
    latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"

    if args.no_metadata:
        metadata = None
    else:
        metadata = {
            "seeds": f"{seed}",
            "prompt": f"{args.prompt}",
            "height": f"{height}",
            "width": f"{width}",
            "video_length": f"{video_length}",
            "infer_steps": f"{args.infer_steps}",
            "guidance_scale": f"{args.guidance_scale}",
        }
        if args.negative_prompt is not None:
            metadata["negative_prompt"] = f"{args.negative_prompt}"

    sd = {"latent": latent}
    save_file(sd, latent_path, metadata=metadata)
    logger.info(f"Latent saved to: {latent_path}")

    return latent_path
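

# A saved latent can be inspected later with the safetensors API imported above
# (sketch; the path is illustrative):
#   with safe_open(latent_path, framework="pt") as f:
#       metadata = f.metadata()          # the string metadata written in save_latent
#       latent = f.get_tensor("latent")  # the latent tensor itself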


def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Save video to file

    Args:
        video: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)

    Returns:
        str: Path to saved video file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    original_name = "" if original_base_name is None else f"_{original_base_name}"
    video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"

    video = video.unsqueeze(0)
    save_videos_grid(video, video_path, fps=args.fps, rescale=True)
    logger.info(f"Video saved to: {video_path}")

    return video_path


def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Save images to directory

    Args:
        sample: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)

    Returns:
        str: Path to saved images directory
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    original_name = "" if original_base_name is None else f"_{original_base_name}"
    image_name = f"{time_flag}_{seed}{original_name}"

    sample = sample.unsqueeze(0)
    save_images_grid(sample, save_path, image_name, rescale=True)
    logger.info(f"Sample images saved to: {save_path}/{image_name}")

    return f"{save_path}/{image_name}"


def save_output(
    latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
) -> None:
    """save output

    Args:
        latent: latent tensor
        args: command line arguments
        cfg: model configuration
        height: height of frame
        width: width of frame
        original_base_names: original base names (if latents are loaded from files)
    """
    if args.output_type == "latent" or args.output_type == "both":
        # save latent
        save_latent(latent, args, height, width)

    if args.output_type == "video" or args.output_type == "both":
        # save video
        sample = decode_latent(latent.unsqueeze(0), args, cfg)
        # pass the bare base name; save_video/save_images add the "_" prefix themselves
        original_name = None if original_base_names is None else original_base_names[0]
        save_video(sample, args, original_name)
    elif args.output_type == "images":
        # save images
        sample = decode_latent(latent.unsqueeze(0), args, cfg)
        original_name = None if original_base_names is None else original_base_names[0]
        save_images(sample, args, original_name)


def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
    """Process multiple prompts for batch mode

    Args:
        prompt_lines: List of prompt lines
        base_args: Base command line arguments

    Returns:
        List[Dict]: List of prompt data dictionaries
    """
    prompts_data = []

    for line in prompt_lines:
        line = line.strip()
        if not line or line.startswith("#"):  # Skip empty lines and comments
            continue

        # Parse prompt line and create override dictionary
        prompt_data = parse_prompt_line(line)
        logger.info(f"Parsed prompt data: {prompt_data}")
        prompts_data.append(prompt_data)

    return prompts_data


def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
    """Process multiple prompts with model reuse

    Args:
        prompts_data: List of prompt data dictionaries
        args: Base command line arguments
    """
    if not prompts_data:
        logger.warning("No valid prompts found")
        return

    # 1. Load configuration
    gen_settings = get_generation_settings(args)
    device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
        gen_settings.device,
        gen_settings.cfg,
        gen_settings.dit_dtype,
        gen_settings.dit_weight_dtype,
        gen_settings.vae_dtype,
    )
    is_i2v = "i2v" in args.task

    # 2. Encode all prompts
    logger.info("Loading text encoder to encode all prompts")
    text_encoder = load_text_encoder(args, cfg, device)
    text_encoder.model.to(device)

    encoded_contexts = {}

    with torch.no_grad():
        for prompt_data in prompts_data:
            prompt = prompt_data["prompt"]
            prompt_args = apply_overrides(args, prompt_data)
            n_prompt = prompt_data.get(
                "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
            )

            if args.fp8_t5:
                with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
                    context = text_encoder([prompt], device)
                    context_null = text_encoder([n_prompt], device)
            else:
                context = text_encoder([prompt], device)
                context_null = text_encoder([n_prompt], device)

            encoded_contexts[prompt] = {"context": context, "context_null": context_null}

    # Free text encoder and clean memory
    del text_encoder
    clean_memory_on_device(device)

    # 3. Process I2V additional encodings if needed
    vae = None
    if is_i2v:
        logger.info("Loading VAE and CLIP for I2V preprocessing")
        vae = load_vae(args, cfg, device, vae_dtype)
        vae.to_device(device)

        clip = load_clip_model(args, cfg, device)
        clip.model.to(device)

        # Process each image and encode with CLIP
        for prompt_data in prompts_data:
            if "image_path" not in prompt_data:
                continue

            prompt_args = apply_overrides(args, prompt_data)
            if not os.path.exists(prompt_args.image_path):
                logger.warning(f"Image path not found: {prompt_args.image_path}")
                continue

            # Load and encode image with CLIP
            img = Image.open(prompt_args.image_path).convert("RGB")
            img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)

            with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
                clip_context = clip.visual([img_tensor[:, None, :, :]])

            encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context

        # Free CLIP and clean memory
        del clip
        clean_memory_on_device(device)

        # Keep VAE in CPU memory for later use
        vae.to_device("cpu")
    elif cfg.is_fun_control:
        # For Fun-Control, we need VAE but keep it on CPU
        vae = load_vae(args, cfg, device, vae_dtype)
        vae.to_device("cpu")

    # 4. Load DiT model
    logger.info("Loading DiT model")
    model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)

    # 5. Merge LoRA weights if needed
    if args.lora_weight is not None and len(args.lora_weight) > 0:
        merge_lora_weights(lora_wan, model, args, device)
        if args.save_merged_model:
            logger.info("Model merged and saved. Exiting.")
            return

    # 6. Optimize model
    optimize_model(model, args, device, dit_dtype, dit_weight_dtype)

    # Create shared models dict for generate function
    shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}

    # 7. Generate for each prompt
    all_latents = []
    all_prompt_args = []

    for i, prompt_data in enumerate(prompts_data):
        logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")

        # Apply overrides for this prompt
        prompt_args = apply_overrides(args, prompt_data)

        # Generate latent
        latent = generate(prompt_args, gen_settings, shared_models)

        # Save latent if needed
        height, width, _ = check_inputs(prompt_args)
        if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
            save_latent(latent, prompt_args, height, width)

        all_latents.append(latent)
        all_prompt_args.append(prompt_args)

    # 8. Free DiT model
    del model
    clean_memory_on_device(device)
    synchronize_device(device)

    # wait for 5 seconds until block swap is done
    logger.info("Waiting for 5 seconds to finish block swap")
    time.sleep(5)

    gc.collect()
    clean_memory_on_device(device)

    # 9. Decode latents if needed
    if args.output_type != "latent":
        logger.info("Decoding latents to videos/images")

        if vae is None:
            vae = load_vae(args, cfg, device, vae_dtype)
        vae.to_device(device)

        for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
            logger.info(f"Decoding output {i+1}/{len(all_latents)}")

            # Decode latent
            video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)

            # Save as video or images
            if prompt_args.output_type == "video" or prompt_args.output_type == "both":
                save_video(video, prompt_args)
            elif prompt_args.output_type == "images":
                save_images(video, prompt_args)

        # Free VAE
        del vae

    clean_memory_on_device(device)
    gc.collect()
def process_interactive(args: argparse.Namespace) -> None:
    """Process prompts in interactive mode

    Args:
        args: Base command line arguments
    """
    gen_settings = get_generation_settings(args)
    device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
        gen_settings.device,
        gen_settings.cfg,
        gen_settings.dit_dtype,
        gen_settings.dit_weight_dtype,
        gen_settings.vae_dtype,
    )
    is_i2v = "i2v" in args.task

    # Initialize models to None
    text_encoder = None
    vae = None
    model = None
    clip = None

    print("Interactive mode. Enter prompts (Ctrl+D to exit):")

    try:
        while True:
            try:
                line = input("> ")
                if not line.strip():
                    continue

                # Parse prompt
                prompt_data = parse_prompt_line(line)
                prompt_args = apply_overrides(args, prompt_data)

                # Ensure we have all the models we need

                # 1. Load text encoder if not already loaded
                if text_encoder is None:
                    logger.info("Loading text encoder")
                    text_encoder = load_text_encoder(args, cfg, device)
                text_encoder.model.to(device)

                # Encode prompt
                n_prompt = prompt_data.get(
                    "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
                )
                with torch.no_grad():
                    if args.fp8_t5:
                        with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
                            context = text_encoder([prompt_data["prompt"]], device)
                            context_null = text_encoder([n_prompt], device)
                    else:
                        context = text_encoder([prompt_data["prompt"]], device)
                        context_null = text_encoder([n_prompt], device)

                encoded_context = {"context": context, "context_null": context_null}

                # Move text encoder to CPU after use
                text_encoder.model.to("cpu")
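
                # NOTE: each heavyweight model follows the same offload pattern:
                # it lives on the CPU between uses and is moved to the GPU only
                # for its own step, so peak GPU memory is bounded by the models
                # actually in use rather than everything loaded at once.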

                # 2. For I2V, we need CLIP and VAE
                if is_i2v:
                    if clip is None:
                        logger.info("Loading CLIP model")
                        clip = load_clip_model(args, cfg, device)
                    clip.model.to(device)

                    # Encode image with CLIP if there's an image path
                    if prompt_args.image_path and os.path.exists(prompt_args.image_path):
                        img = Image.open(prompt_args.image_path).convert("RGB")
                        img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
                        with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
                            clip_context = clip.visual([img_tensor[:, None, :, :]])
                        encoded_context["clip_context"] = clip_context

                    # Move CLIP to CPU after use
                    clip.model.to("cpu")

                    # Load VAE if needed
                    if vae is None:
                        logger.info("Loading VAE model")
                        vae = load_vae(args, cfg, device, vae_dtype)
                elif cfg.is_fun_control and vae is None:
                    # Fun-Control also needs the VAE (to encode the control inputs)
                    logger.info("Loading VAE model for Fun-Control")
                    vae = load_vae(args, cfg, device, vae_dtype)

                # 3. Load DiT model if not already loaded
                if model is None:
                    logger.info("Loading DiT model")
                    model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)

                    # Merge LoRA weights if needed
                    if args.lora_weight is not None and len(args.lora_weight) > 0:
                        merge_lora_weights(lora_wan, model, args, device)

                    # Optimize model
                    optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
                else:
                    # Move model to GPU if it was offloaded
                    model.to(device)

                # Create shared models dict
                shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}

                # Generate latent
                latent = generate(prompt_args, gen_settings, shared_models)

                # Move model to CPU after generation
                model.to("cpu")

                # Save latent if needed
                height, width, _ = check_inputs(prompt_args)
                if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
                    save_latent(latent, prompt_args, height, width)

                # Decode and save output
                if prompt_args.output_type != "latent":
                    if vae is None:
                        vae = load_vae(args, cfg, device, vae_dtype)
                    vae.to_device(device)
                    video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)

                    if prompt_args.output_type == "video" or prompt_args.output_type == "both":
                        save_video(video, prompt_args)
                    elif prompt_args.output_type == "images":
                        save_images(video, prompt_args)

                    # Move VAE to CPU after use
                    vae.to_device("cpu")

                clean_memory_on_device(device)

            except KeyboardInterrupt:
                print("\nInterrupted. Continuing (Ctrl+D, or Ctrl+Z on Windows, to exit)")
                continue

    except EOFError:
        print("\nExiting interactive mode")

    # Clean up all models
    if text_encoder is not None:
        del text_encoder
    if clip is not None:
        del clip
    if vae is not None:
        del vae
    if model is not None:
        del model

    clean_memory_on_device(device)
    gc.collect()

def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
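    """Resolve the device and the DiT/VAE dtypes for this run.

    Note: checkpoints stored in fp8 (detected via itemsize == 1) still run
    their activations in bfloat16, and --fp8_scaled is rejected for them
    because already-fp8 weights cannot be scaled to fp8 again.
    """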
    device = torch.device(args.device)
    cfg = WAN_CONFIGS[args.task]

    # select dtype
    dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
    if dit_dtype.itemsize == 1:
        # if weights are in fp8, use bfloat16 for the DiT computation (input/output)
        dit_dtype = torch.bfloat16
        if args.fp8_scaled:
            raise ValueError(
                "DiT weights are already in fp8 format and cannot be scaled to fp8. Please use fp16/bf16 weights."
            )

    dit_weight_dtype = dit_dtype  # default
    if args.fp8_scaled:
        dit_weight_dtype = None  # various precision weights, so don't cast to specific dtype
    elif args.fp8:
        dit_weight_dtype = torch.float8_e4m3fn
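    # Resulting weight handling:
    #   --fp8_scaled -> dit_weight_dtype is None (mixed-precision weights, no
    #                   uniform cast at load time)
    #   --fp8        -> weights cast to float8_e4m3fn
    #   neither      -> weights kept in the computation dtype (bf16/fp16)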

    vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
    logger.info(
        f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
    )

    gen_settings = GenerationSettings(
        device=device,
        cfg=cfg,
        dit_dtype=dit_dtype,
        dit_weight_dtype=dit_weight_dtype,
        vae_dtype=vae_dtype,
    )
    return gen_settings

def main():
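    """Entry point. Dispatches to one of four modes:

    1. latent decode mode (--latent_path): decode previously saved latents
    2. batch mode (--from_file): process every prompt line in a file
    3. interactive mode (--interactive): read prompt lines from stdin
    4. single prompt mode: the original one-shot behavior
    """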
    # Parse arguments
    args = parse_args()

    # Check if latents are provided
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0

    # Set device
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    args.device = device

    if latents_mode:
        # Original latent decode mode
        cfg = WAN_CONFIGS[args.task]  # any task is fine
        original_base_names = []
        latents_list = []
        seeds = []

        assert len(args.latent_path) == 1, "Only one latent path is supported for now"

        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                # restore generation parameters from the safetensors metadata
                if "seeds" in metadata:
                    seed = int(metadata["seeds"])
                if "height" in metadata and "width" in metadata:
                    height = int(metadata["height"])
                    width = int(metadata["width"])
                    args.video_size = [height, width]
                if "video_length" in metadata:
                    args.video_length = int(metadata["video_length"])

            seeds.append(seed)
            latents_list.append(latents)

            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")

        latent = torch.stack(latents_list, dim=0)  # [N, ...]; all latents must share the same shape

        # reconstruct pixel dimensions and frame count from the latent shape
        # (latents still refers to the single loaded latent; see the assert above)
        height = latents.shape[-2]
        width = latents.shape[-1]
        height *= cfg.patch_size[1] * cfg.vae_stride[1]
        width *= cfg.patch_size[2] * cfg.vae_stride[2]

        video_length = latents.shape[1]
        video_length = (video_length - 1) * cfg.vae_stride[0] + 1
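        # Worked example, assuming cfg.vae_stride[0] == 4 (the Wan 2.1 default):
        # a latent with 21 frames decodes to (21 - 1) * 4 + 1 = 81 video frames.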

        args.seed = seeds[0]

        # Decode and save
        save_output(latent[0], args, cfg, height, width, original_base_names)

    elif args.from_file:
        # Batch mode from file
        args = setup_args(args)

        # Read prompts from file
        with open(args.from_file, "r", encoding="utf-8") as f:
            prompt_lines = f.readlines()

        # Process prompts
        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
        process_batch_prompts(prompts_data, args)

    elif args.interactive:
        # Interactive mode
        args = setup_args(args)
        process_interactive(args)

    else:
        # Single prompt mode (original behavior)
        args = setup_args(args)

        height, width, video_length = check_inputs(args)
        logger.info(
            f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
            f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
        )

        # Generate latent
        gen_settings = get_generation_settings(args)
        latent = generate(args, gen_settings)

        # Make sure the model is freed from GPU memory
        gc.collect()
        clean_memory_on_device(args.device)

        # Save latent and video
        if args.save_merged_model:
            return

        # Add batch dimension
        latent = latent.unsqueeze(0)
        save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)

    logger.info("Done!")


if __name__ == "__main__":
    main()