Framepack-H111 / framepack_generate_video.py
rahul7star's picture
Upload 303 files
e0336bc verified
# Combined and Corrected Script
#!/usr/bin/env python3
import argparse
import os
import sys
import time
import random
import traceback
from datetime import datetime
from pathlib import Path
import re # For parsing section args
import einops
import numpy as np
import torch
import av # For saving video (used by save_bcthw_as_mp4)
from PIL import Image
from tqdm import tqdm
import cv2
# --- Dependencies from diffusers_helper ---
# Ensure this library is installed or in the PYTHONPATH
try:
# from diffusers_helper.hf_login import login # Not strictly needed for inference if models public/cached
from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode #, vae_decode_fake # vae_decode_fake not used here
from diffusers_helper.utils import (save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw,
resize_and_center_crop, generate_timestamp)
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import (cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation,
offload_model_from_device_for_memory_preservation, fake_diffusers_current_device,
DynamicSwapInstaller, unload_complete_models, load_model_as_complete)
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket#, bucket_options # bucket_options no longer needed here
except ImportError:
print("Error: Could not import modules from 'diffusers_helper'.")
print("Please ensure the 'diffusers_helper' library is installed and accessible.")
print("You might need to clone the repository and add it to your PYTHONPATH.")
sys.exit(1)
# --- End Dependencies ---
from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
from transformers import SiglipImageProcessor, SiglipVisionModel
# --- Constants ---
DIMENSION_MULTIPLE = 16 # VAE and model constraints often require divisibility by 8 or 16. 16 is safer.
SECTION_ARG_PATTERN = re.compile(r"^(\d+):([^:]+)(?::(.*))?$") # Regex for section arg: number:image_path[:prompt]


def parse_section_args(section_strings):
    """Parse repeated ``--section`` values into ``{index: (image_path, prompt_or_None)}``.

    Each entry has the form ``index:image_path[:prompt]``.
    - Entries that do not match that format are skipped with a warning.
    - Entries whose image file does not exist are skipped with a warning.
    - A later duplicate index overwrites an earlier one.
    - An empty or missing prompt is stored as ``None``.

    Windows drive-letter paths (``0:C:\\imgs\\a.png:prompt``) are also accepted.
    ``SECTION_ARG_PATTERN`` alone would split them at the drive colon.

    Args:
        section_strings: Iterable of raw ``--section`` strings, or ``None``/empty.

    Returns:
        dict mapping the int section index to ``(image_path, prompt_text_or_None)``.
    """
    section_data = {}
    if not section_strings:
        return section_data
    for section_str in section_strings:
        match = SECTION_ARG_PATTERN.match(section_str)
        if not match:
            print(f"Warning: Invalid section format: '{section_str}'. Expected 'number:image_path[:prompt]'. Skipping.")
            continue
        section_index_str, image_path, prompt_text = match.groups()
        # "N:C:\dir\img.png[:prompt]" matches as image_path="C", prompt_text="\dir\img.png[:prompt]".
        # Re-join the drive letter with its path. This is skipped when a file literally
        # named like the single letter exists, so the old interpretation is preserved.
        if (len(image_path) == 1 and image_path.isalpha()
                and prompt_text and prompt_text[0] in ('\\', '/')
                and not os.path.exists(image_path)):
            path_rest, _, prompt_text = prompt_text.partition(':')
            image_path = f"{image_path}:{path_rest}"
        section_index = int(section_index_str)
        prompt_text = prompt_text if prompt_text else None
        if not os.path.exists(image_path):
            print(f"Warning: Image path for section {section_index} ('{image_path}') not found. Skipping section.")
            continue
        if section_index in section_data:
            print(f"Warning: Duplicate section index {section_index}. Overwriting previous entry.")
        section_data[section_index] = (image_path, prompt_text)
        print(f"Parsed section {section_index}: Image='{image_path}', Prompt='{prompt_text}'")
    return section_data
def _build_arg_parser():
    """Create the argparse parser describing every CLI option (no validation here)."""
    parser = argparse.ArgumentParser(description="FramePack HunyuanVideo inference script (CLI version with Advanced End Frame & Section Control)")
    # --- Model Paths ---
    parser.add_argument('--transformer_path', type=str, default='lllyasviel/FramePackI2V_HY', help="Path to the FramePack Transformer model")
    parser.add_argument('--vae_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the VAE model directory")
    parser.add_argument('--text_encoder_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the Llama text encoder directory")
    parser.add_argument('--text_encoder_2_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the CLIP text encoder directory")
    parser.add_argument('--image_encoder_path', type=str, default='lllyasviel/flux_redux_bfl', help="Path to the SigLIP image encoder directory")
    parser.add_argument('--hf_home', type=str, default='./hf_download', help="Directory to download/cache Hugging Face models")
    # --- Input ---
    parser.add_argument("--input_image", type=str, required=True, help="Path to the input image (start frame)")
    parser.add_argument("--end_frame", type=str, default=None, help="Path to the optional end frame image (video end)")
    parser.add_argument("--prompt", type=str, required=True, help="Default prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for generation")
    # --- End Frame Control ---
    parser.add_argument("--end_frame_weight", type=float, default=0.3, help="End frame influence weight (0.0-1.0) for blending modes ('half', 'progressive'). Higher blends more end frame *conditioning latent*.")
    parser.add_argument("--end_frame_influence", type=str, default="last",
                        choices=["last", "half", "progressive", "bookend"],
                        help="How to use the global end frame: 'last' (uses end frame for initial context only, no latent blending), 'half' (blends start/end conditioning latents for second half of video), 'progressive' (gradually blends conditioning latents from end to start), 'bookend' (uses end frame conditioning latent ONLY for first generated section IF no section keyframe set, no blending otherwise). All modes use start image embedding.")
    # --- Section Control ---
    parser.add_argument("--section", type=str, action='append',
                        help="Define a keyframe section. Format: 'index:image_path[:prompt]'. Index 0 is the last generated section (video start), 1 is second last, etc. Repeat for multiple sections. Example: --section 0:path/to/start_like.png:'A sunrise' --section 2:path/to/mid.png")
    # --- Output Resolution (Choose ONE method) ---
    parser.add_argument("--target_resolution", type=int, default=None, help=f"Target resolution for the longer side for automatic aspect ratio calculation (bucketing). Used if --width and --height are not specified. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--width", type=int, default=None, help=f"Explicit target width for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--height", type=int, default=None, help=f"Explicit target height for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    # --- Output ---
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save the generated video")
    parser.add_argument("--save_intermediate_sections", action='store_true', help="Save the video after each section is generated and decoded.")
    parser.add_argument("--save_section_final_frames", action='store_true', help="Save the final decoded frame of each generated section as a PNG image.")
    # --- Generation Parameters (Matching Gradio Demo Defaults where applicable) ---
    parser.add_argument("--seed", type=int, default=None, help="Seed for generation. Random if not set.")
    parser.add_argument("--total_second_length", type=float, default=5.0, help="Total desired video length in seconds")
    parser.add_argument("--fps", type=int, default=30, help="Frames per second for the output video")
    parser.add_argument("--steps", type=int, default=25, help="Number of inference steps (changing not recommended)")
    parser.add_argument("--distilled_guidance_scale", "--gs", type=float, default=10.0, help="Distilled CFG Scale (gs)")
    parser.add_argument("--cfg", type=float, default=1.0, help="Classifier-Free Guidance Scale (fixed at 1.0 for FramePack usually)")
    parser.add_argument("--rs", type=float, default=0.0, help="CFG Rescale (fixed at 0.0 for FramePack usually)")
    parser.add_argument("--latent_window_size", type=int, default=9, help="Latent window size (changing not recommended)")
    # --- Performance / Memory ---
    parser.add_argument('--high_vram', action='store_true', help="Force high VRAM mode (loads all models to GPU)")
    parser.add_argument('--low_vram', action='store_true', help="Force low VRAM mode (uses dynamic swapping)")
    parser.add_argument("--gpu_memory_preservation", type=float, default=6.0, help="GPU memory (GB) to preserve when offloading (low VRAM mode)")
    # --use_teacache is effectively a no-op (default already True); --no_teacache turns it off.
    parser.add_argument('--use_teacache', action='store_true', default=True, help="Use TeaCache optimization (default: True)")
    parser.add_argument('--no_teacache', action='store_false', dest='use_teacache', help="Disable TeaCache optimization")
    parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda', 'cpu'). Auto-detects if None.")
    return parser


def parse_args():
    """Parse and validate the command-line arguments.

    Besides validation, this function:
      * fills ``args.seed`` with a random value when not given;
      * resolves the resolution mode. Explicit ``--width``/``--height`` (which must be
        given together) wins. Otherwise ``--target_resolution`` is used, defaulting to 640;
      * stores the parsed ``--section`` entries in ``args.section_data``;
      * sets ``HF_HOME`` to the absolute ``--hf_home`` path and creates that directory.

    Exits the process with status 1 on invalid input.

    Returns:
        argparse.Namespace with the extra ``section_data`` attribute.
    """
    args = _build_arg_parser().parse_args()

    # --- Argument Validation ---
    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
        print(f"Generated random seed: {args.seed}")

    if not os.path.exists(args.input_image):
        print(f"Error: Input image not found at '{args.input_image}'.")
        sys.exit(1)

    # Width and height only work as a pair. Without this check, a lone --width/--height
    # would be silently ignored in favour of --target_resolution 640.
    if (args.width is None) != (args.height is None):
        print("Error: --width and --height must be specified together. Provide both or neither.")
        sys.exit(1)

    if args.width is not None:  # both are set (checked above)
        if args.width <= 0 or args.height <= 0:
            print(f"Error: Explicit --width ({args.width}) and --height ({args.height}) must be positive.")
            sys.exit(1)
        if args.target_resolution is not None:
            print("Warning: Both --width/--height and --target_resolution specified. Using explicit --width and --height.")
            args.target_resolution = None
    elif args.target_resolution is not None:
        if args.target_resolution <= 0:
            print(f"Error: --target_resolution ({args.target_resolution}) must be positive.")
            sys.exit(1)
    else:
        print("Warning: No resolution specified. Defaulting to --target_resolution 640.")
        args.target_resolution = 640

    # The section count is later divided by (latent_window_size * 4).
    if args.latent_window_size <= 0:
        print(f"Error: --latent_window_size must be positive (got {args.latent_window_size}).")
        sys.exit(1)

    if not 0.0 <= args.end_frame_weight <= 1.0:
        print(f"Error: --end_frame_weight must be between 0.0 and 1.0 (got {args.end_frame_weight}).")
        sys.exit(1)

    if args.width is not None and args.width % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --width ({args.width}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.height is not None and args.height % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --height ({args.height}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.target_resolution is not None and args.target_resolution % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --target_resolution ({args.target_resolution}) is not divisible by {DIMENSION_MULTIPLE}. The calculated dimensions will be rounded down.")

    if args.end_frame and not os.path.exists(args.end_frame):
        print(f"Error: End frame image not found at '{args.end_frame}'.")
        sys.exit(1)

    args.section_data = parse_section_args(args.section)

    # NOTE(review): transformers/diffusers (and so huggingface_hub) are imported at the top of
    # this file, before this runs. huggingface_hub appears to read HF_HOME at import time,
    # so this assignment may not change the download cache location. Confirm this.
    os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(args.hf_home))
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)
    return args
def load_models(args):
    """Load the tokenizers, text/image encoders, VAE and FramePack transformer.

    Every model is loaded onto the CPU:
      * the transformer in bfloat16;
      * the other weighted models in float16.
    Each model is then put in eval mode with gradients disabled. Moving models to the GPU
    is left to the caller.

    Args:
        args: Parsed CLI namespace; reads ``device`` and the ``*_path`` options.

    Returns:
        dict with the keys "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2",
        "vae", "feature_extractor", "image_encoder", "transformer" and "device".

    NOTE(review): ``HF_HOME`` is set by ``parse_args`` only after transformers/diffusers have
    already been imported. huggingface_hub seems to fix its default cache path at import
    time, so ``--hf_home`` may not control where these downloads land. Confirm this. If
    so, pass ``cache_dir`` explicitly, keeping users who export HF_HOME themselves in mind.
    """
    print("Loading models...")
    device = torch.device(args.device) if args.device else torch.device(gpu if torch.cuda.is_available() else cpu)
    print(f"Using device: {device}")

    print(" Loading Text Encoder 1 (Llama)...")
    text_encoder = LlamaModel.from_pretrained(
        args.text_encoder_path, subfolder='text_encoder', torch_dtype=torch.float16
    ).cpu()

    print(" Loading Text Encoder 2 (CLIP)...")
    text_encoder_2 = CLIPTextModel.from_pretrained(
        args.text_encoder_2_path, subfolder='text_encoder_2', torch_dtype=torch.float16
    ).cpu()

    print(" Loading Tokenizer 1 (Llama)...")
    tokenizer = LlamaTokenizerFast.from_pretrained(args.text_encoder_path, subfolder='tokenizer')

    print(" Loading Tokenizer 2 (CLIP)...")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path, subfolder='tokenizer_2')

    print(" Loading VAE...")
    vae = AutoencoderKLHunyuanVideo.from_pretrained(
        args.vae_path, subfolder='vae', torch_dtype=torch.float16
    ).cpu()

    print(" Loading Image Feature Extractor (SigLIP)...")
    feature_extractor = SiglipImageProcessor.from_pretrained(args.image_encoder_path, subfolder='feature_extractor')

    print(" Loading Image Encoder (SigLIP)...")
    image_encoder = SiglipVisionModel.from_pretrained(
        args.image_encoder_path, subfolder='image_encoder', torch_dtype=torch.float16
    ).cpu()

    print(" Loading Transformer (FramePack)...")
    transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
        args.transformer_path, torch_dtype=torch.bfloat16
    ).cpu()

    # Inference only: eval mode first, then freeze all parameters.
    weighted_models = (vae, text_encoder, text_encoder_2, image_encoder, transformer)
    for model in weighted_models:
        model.eval()
    # Flag name suggests the transformer emits fp32 output during inference; semantics live in diffusers_helper.
    transformer.high_quality_fp32_output_for_inference = True
    print('transformer.high_quality_fp32_output_for_inference = True')
    for model in weighted_models:
        model.requires_grad_(False)

    print("Models loaded.")
    return dict(
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        vae=vae,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        transformer=transformer,
        device=device,
    )
def adjust_to_multiple(value, multiple):
    """Round ``value`` down to the nearest multiple of ``multiple``.

    Uses floor division, so negative values round toward negative infinity.
    Example: ``adjust_to_multiple(650, 16) == 640``.
    """
    quotient, _remainder = divmod(value, multiple)
    return quotient * multiple
def mix_latents(latent_a, latent_b, weight_b):
    """Linearly blend two latent tensors: ``(1 - weight_b) * latent_a + weight_b * latent_b``.

    Behaviour:
      * If either latent is ``None``, the other one is returned unchanged.
      * ``latent_b`` is moved to ``latent_a``'s device and dtype before blending.
      * ``weight_b`` may be a Python number or a one-element tensor. It is clamped to [0, 1].
      * A weight of exactly 0 returns ``latent_a`` itself, without computing a blend.
      * A weight of exactly 1 returns the (converted) ``latent_b``, without computing a blend.
    """
    if latent_a is None:
        return latent_b
    if latent_b is None:
        return latent_a

    # Tensor.to returns the same tensor when device and dtype already match.
    latent_b = latent_b.to(device=latent_a.device, dtype=latent_a.dtype)

    weight = weight_b.item() if isinstance(weight_b, torch.Tensor) else weight_b
    weight = max(0.0, min(1.0, weight))

    if weight == 0.0:
        return latent_a
    if weight == 1.0:
        return latent_b
    return (1.0 - weight) * latent_a + weight * latent_b
def mix_embeddings(embed_a, embed_b, weight_b):
    """Blend two embedding tensors: ``(1 - weight_b) * embed_a + weight_b * embed_b``.

    Behaviour:
      * A ``None`` operand makes the function return the other operand as-is.
      * ``embed_b`` is first converted to ``embed_a``'s device and dtype.
      * ``weight_b`` (a number or one-element tensor) is clamped to [0, 1].
      * The endpoints 0 and 1 return the corresponding input directly.
    """
    if embed_a is None or embed_b is None:
        return embed_b if embed_a is None else embed_a

    # Align embed_b with embed_a, converting only what actually differs.
    if embed_b.device != embed_a.device:
        embed_b = embed_b.to(embed_a.device)
    if embed_b.dtype != embed_a.dtype:
        embed_b = embed_b.to(dtype=embed_a.dtype)

    if isinstance(weight_b, torch.Tensor):
        weight_b = weight_b.item()
    clamped = max(0.0, min(1.0, weight_b))

    if clamped == 0.0:
        return embed_a
    if clamped == 1.0:
        return embed_b
    return (1.0 - clamped) * embed_a + clamped * embed_b
def preprocess_image_for_generation(image_path, target_width, target_height, job_id, output_dir, frame_name="input"):
    """Load one image, resize/center-crop it, save a preview and build a model-ready tensor.

    The preview PNG is written to
    ``output_dir / f"{job_id}_{frame_name}_resized_{target_width}x{target_height}.png"``.
    A failure to write it only prints a warning.

    Args:
        image_path: Path of the image to load; it is converted to RGB.
        target_width, target_height: Output size handed to ``resize_and_center_crop``.
        job_id: Prefix for the preview file name.
        output_dir: Directory for the preview. It is joined with ``/``, so a
            ``pathlib.Path`` is expected.
        frame_name: Label used in log messages and in the preview file name.

    Returns:
        A tuple ``(image_np, image_resized_np, image_pt)``:
          * ``image_np``: the original RGB array (H, W, 3);
          * ``image_resized_np``: the resized/cropped array;
          * ``image_pt``: a float tensor computed as ``x / 127.5 - 1`` and laid out as
            (1, 3, 1, H, W). The range is [-1, 1], assuming 0-255 pixel values.

    Raises:
        Re-raises whatever exception occurs while opening/decoding the image, after
        printing it.
    """
    try:
        with Image.open(image_path) as pil_image:
            image_np = np.array(pil_image.convert('RGB'))
    except Exception as e:
        print(f"Error loading image '{image_path}': {e}")
        raise

    orig_h, orig_w, _ = image_np.shape
    print(f" {frame_name.capitalize()} image loaded ({orig_w}x{orig_h}): '{image_path}'")

    image_resized_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)

    # Best-effort preview; path construction stays inside the try like the save itself.
    try:
        preview_file = output_dir / f'{job_id}_{frame_name}_resized_{target_width}x{target_height}.png'
        Image.fromarray(image_resized_np).save(preview_file)
    except Exception as e:
        print(f"Warning: Could not save resized image preview for {frame_name}: {e}")

    # Map pixels to [-1, 1], then HWC -> (B=1, C=3, T=1, H, W).
    normalized = torch.from_numpy(image_resized_np).float() / 127.5 - 1.0
    image_pt = normalized.permute(2, 0, 1).unsqueeze(0).unsqueeze(2)
    print(f" {frame_name.capitalize()} image processed to tensor shape: {image_pt.shape}")
    return image_np, image_resized_np, image_pt
@torch.no_grad()
def generate_video(args, models):
"""Generates the video using the loaded models and arguments."""
# Unpack models
text_encoder = models["text_encoder"]
text_encoder_2 = models["text_encoder_2"]
tokenizer = models["tokenizer"]
tokenizer_2 = models["tokenizer_2"]
vae = models["vae"]
feature_extractor = models["feature_extractor"]
image_encoder = models["image_encoder"]
transformer = models["transformer"]
device = models["device"]
# --- Determine Memory Mode ---
if args.high_vram and args.low_vram:
print("Warning: Both --high_vram and --low_vram specified. Defaulting to auto-detection.")
force_high_vram = force_low_vram = False
else:
force_high_vram = args.high_vram
force_low_vram = args.low_vram
if force_high_vram:
high_vram = True
elif force_low_vram:
high_vram = False
else:
free_mem_gb = get_cuda_free_memory_gb(device) if device.type == 'cuda' else 0
high_vram = free_mem_gb > 60
print(f'Auto-detected Free VRAM {free_mem_gb:.2f} GB -> High-VRAM Mode: {high_vram}')
# --- Configure Models based on VRAM mode ---
if not high_vram:
print("Configuring for Low VRAM mode...")
vae.enable_slicing()
vae.enable_tiling()
print(" Installing DynamicSwap for Transformer...")
DynamicSwapInstaller.install_model(transformer, device=device)
print(" Installing DynamicSwap for Text Encoder 1...")
DynamicSwapInstaller.install_model(text_encoder, device=device)
print("Unloading models from GPU (Low VRAM setup)...")
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
else:
print("Configuring for High VRAM mode (moving models to GPU)...")
text_encoder.to(device)
text_encoder_2.to(device)
image_encoder.to(device)
vae.to(device)
transformer.to(device)
print(" Models moved to GPU.")
# --- Prepare Inputs ---
print("Preparing inputs...")
prompt = args.prompt
n_prompt = args.negative_prompt
seed = args.seed
total_second_length = args.total_second_length
latent_window_size = args.latent_window_size
steps = args.steps
cfg = args.cfg
gs = args.distilled_guidance_scale
rs = args.rs
gpu_memory_preservation = args.gpu_memory_preservation
use_teacache = args.use_teacache
fps = args.fps
end_frame_path = args.end_frame
end_frame_influence = args.end_frame_influence
end_frame_weight = args.end_frame_weight
section_data = args.section_data
save_intermediate = args.save_intermediate_sections
save_section_frames = args.save_section_final_frames
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
total_latent_sections = int(max(round(total_latent_sections), 1))
print(f"Calculated total latent sections: {total_latent_sections}")
job_id = generate_timestamp() + f"_seed{seed}"
output_dir = Path(args.save_path)
output_dir.mkdir(parents=True, exist_ok=True)
final_video_path = None
# --- Section Preprocessing Storage ---
section_latents = {}
section_image_embeddings = {} # Still store, might be useful later
section_prompt_embeddings = {}
try:
# --- Text Encoding (Global Prompts) ---
print("Encoding global text prompts...")
if not high_vram:
print(" Low VRAM mode: Loading Text Encoders to GPU...")
fake_diffusers_current_device(text_encoder, device)
load_model_as_complete(text_encoder_2, target_device=device)
print(" Text Encoders loaded.")
global_llama_vec, global_clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
if cfg == 1.0:
print(" CFG scale is 1.0, using zero negative embeddings.")
global_llama_vec_n, global_clip_l_pooler_n = torch.zeros_like(global_llama_vec), torch.zeros_like(global_clip_l_pooler)
else:
print(f" Encoding negative prompt: '{n_prompt}'")
global_llama_vec_n, global_clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
global_llama_vec, global_llama_attention_mask = crop_or_pad_yield_mask(global_llama_vec, length=512)
global_llama_vec_n, global_llama_attention_mask_n = crop_or_pad_yield_mask(global_llama_vec_n, length=512)
print(" Global text encoded and processed.")
# --- Section Text Encoding ---
if section_data:
print("Encoding section-specific prompts...")
for section_index, (img_path, prompt_text) in section_data.items():
if prompt_text:
print(f" Encoding prompt for section {section_index}: '{prompt_text}'")
sec_llama_vec, sec_clip_pooler = encode_prompt_conds(prompt_text, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
sec_llama_vec, _ = crop_or_pad_yield_mask(sec_llama_vec, length=512)
section_prompt_embeddings[section_index] = (
sec_llama_vec.cpu().to(transformer.dtype),
sec_clip_pooler.cpu().to(transformer.dtype)
)
print(f" Section {section_index} prompt encoded and stored on CPU.")
else:
print(f" Section {section_index} has no specific prompt, will use global prompt.")
if not high_vram:
print(" Low VRAM mode: Unloading Text Encoders from GPU...")
unload_complete_models(text_encoder_2)
print(" Text Encoder 2 unloaded.")
# --- Input Image Processing & Dimension Calculation ---
print("Processing input image and determining dimensions...")
try:
input_image_np_orig, _, _ = preprocess_image_for_generation(
args.input_image, 1, 1, job_id, output_dir, "temp_input_orig"
)
except Exception as e:
print(f"Error loading input image '{args.input_image}' for dimension check: {e}")
raise
H_orig, W_orig, _ = input_image_np_orig.shape
print(f" Input image original size: {W_orig}x{H_orig}")
if args.width is not None and args.height is not None:
target_w, target_h = args.width, args.height
print(f" Using explicit target dimensions: {target_w}x{target_h}")
elif args.target_resolution is not None:
print(f" Calculating dimensions based on target resolution for longer side: {args.target_resolution}")
target_h, target_w = find_nearest_bucket(H_orig, W_orig, resolution=args.target_resolution)
print(f" Calculated dimensions (before adjustment): {target_w}x{target_h}")
else:
raise ValueError("Internal Error: Resolution determination failed.")
final_w = adjust_to_multiple(target_w, DIMENSION_MULTIPLE)
final_h = adjust_to_multiple(target_h, DIMENSION_MULTIPLE)
if final_w <= 0 or final_h <= 0:
print(f"Error: Calculated dimensions ({target_w}x{target_h}) resulted in non-positive dimensions after adjusting to be divisible by {DIMENSION_MULTIPLE} ({final_w}x{final_h}).")
raise ValueError("Adjusted dimensions are invalid.")
if final_w != target_w or final_h != target_h:
print(f"Warning: Adjusted dimensions from {target_w}x{target_h} to {final_w}x{final_h} to be divisible by {DIMENSION_MULTIPLE}.")
else:
print(f" Final dimensions confirmed: {final_w}x{final_h}")
width, height = final_w, final_h
if width * height > 1024 * 1024:
print(f"Warning: Target resolution {width}x{height} is large. Ensure you have sufficient VRAM.")
_, input_image_resized_np, input_image_pt = preprocess_image_for_generation(
args.input_image, width, height, job_id, output_dir, "input"
)
end_frame_resized_np = None
end_frame_pt = None
if end_frame_path:
_, end_frame_resized_np, end_frame_pt = preprocess_image_for_generation(
end_frame_path, width, height, job_id, output_dir, "end"
)
section_images_resized_np = {}
section_images_pt = {}
if section_data:
print("Processing section keyframe images...")
for section_index, (img_path, _) in section_data.items():
_, sec_resized_np, sec_pt = preprocess_image_for_generation(
img_path, width, height, job_id, output_dir, f"section{section_index}"
)
section_images_resized_np[section_index] = sec_resized_np
section_images_pt[section_index] = sec_pt
# --- VAE Encoding ---
print("VAE encoding initial frame...")
if not high_vram:
print(" Low VRAM mode: Loading VAE to GPU...")
load_model_as_complete(vae, target_device=device)
print(" VAE loaded.")
input_image_pt_dev = input_image_pt.to(device=device, dtype=vae.dtype)
start_latent = vae_encode(input_image_pt_dev, vae) # GPU, vae.dtype
print(f" Initial latent shape: {start_latent.shape}")
print(f" Start latent stats - Min: {start_latent.min().item():.4f}, Max: {start_latent.max().item():.4f}, Mean: {start_latent.mean().item():.4f}")
end_frame_latent = None
if end_frame_pt is not None:
print("VAE encoding end frame...")
end_frame_pt_dev = end_frame_pt.to(device=device, dtype=vae.dtype)
end_frame_latent = vae_encode(end_frame_pt_dev, vae) # GPU, vae.dtype
print(f" End frame latent shape: {end_frame_latent.shape}")
print(f" End frame latent stats - Min: {end_frame_latent.min().item():.4f}, Max: {end_frame_latent.max().item():.4f}, Mean: {end_frame_latent.mean().item():.4f}")
if end_frame_latent.shape != start_latent.shape:
print(f"Warning: End frame latent shape mismatch. Reshaping.")
try:
end_frame_latent = end_frame_latent.reshape(start_latent.shape)
except Exception as reshape_err:
print(f"Error reshaping end frame latent: {reshape_err}. Disabling end frame.")
end_frame_latent = None
if section_images_pt:
print("VAE encoding section keyframes...")
for section_index, sec_pt in section_images_pt.items():
sec_pt_dev = sec_pt.to(device=device, dtype=vae.dtype)
sec_latent = vae_encode(sec_pt_dev, vae) # GPU, vae.dtype
print(f" Section {section_index} latent shape: {sec_latent.shape}")
if sec_latent.shape != start_latent.shape:
print(f" Warning: Section {section_index} latent shape mismatch. Reshaping.")
try:
sec_latent = sec_latent.reshape(start_latent.shape)
except Exception as reshape_err:
print(f" Error reshaping section {section_index} latent: {reshape_err}. Skipping section latent.")
continue
# Store on CPU as float32 for context/blending later
section_latents[section_index] = sec_latent.cpu().float()
print(f" Section {section_index} latent encoded and stored on CPU.")
if not high_vram:
print(" Low VRAM mode: Unloading VAE from GPU...")
unload_complete_models(vae)
print(" VAE unloaded.")
# Move essential latents to CPU as float32 for context/blending
start_latent = start_latent.cpu().float()
if end_frame_latent is not None:
end_frame_latent = end_frame_latent.cpu().float()
# --- CLIP Vision Encoding ---
print("CLIP Vision encoding image(s)...")
if not high_vram:
print(" Low VRAM mode: Loading Image Encoder to GPU...")
load_model_as_complete(image_encoder, target_device=device)
print(" Image Encoder loaded.")
# Encode start frame - WILL BE USED CONSISTENTLY for image_embeddings
image_encoder_output = hf_clip_vision_encode(input_image_resized_np, feature_extractor, image_encoder)
start_image_embedding = image_encoder_output.last_hidden_state # GPU, image_encoder.dtype
print(f" Start image embedding shape: {start_image_embedding.shape}")
# Encode end frame (if provided) - Only needed if extending later
# end_frame_embedding = None # Not needed for this strategy
# if end_frame_resized_np is not None:
# pass # Skip encoding for now
# Encode section frames (if provided) - Store for potential future use
if section_images_resized_np:
print("CLIP Vision encoding section keyframes (storing on CPU)...")
for section_index, sec_resized_np in section_images_resized_np.items():
sec_output = hf_clip_vision_encode(sec_resized_np, feature_extractor, image_encoder)
sec_embedding = sec_output.last_hidden_state
section_image_embeddings[section_index] = sec_embedding.cpu().to(transformer.dtype)
print(f" Section {section_index} embedding shape: {sec_embedding.shape}. Stored on CPU.")
if not high_vram:
print(" Low VRAM mode: Unloading Image Encoder from GPU...")
unload_complete_models(image_encoder)
print(" Image Encoder unloaded.")
# Move start image embedding to CPU (transformer dtype)
target_dtype = transformer.dtype
start_image_embedding = start_image_embedding.cpu().to(target_dtype)
# --- Prepare Global Embeddings for Transformer (CPU, transformer.dtype) ---
print("Preparing global embeddings for Transformer...")
global_llama_vec = global_llama_vec.cpu().to(target_dtype)
global_llama_vec_n = global_llama_vec_n.cpu().to(target_dtype)
global_clip_l_pooler = global_clip_l_pooler.cpu().to(target_dtype)
global_clip_l_pooler_n = global_clip_l_pooler_n.cpu().to(target_dtype)
print(f" Global Embeddings prepared on CPU with dtype {target_dtype}.")
# --- Sampling Setup ---
print("Setting up sampling...")
rnd = torch.Generator(cpu).manual_seed(seed)
num_frames = latent_window_size * 4 - 3
print(f" Latent frames per sampling step (num_frames input): {num_frames}")
latent_c, latent_h, latent_w = start_latent.shape[1], start_latent.shape[3], start_latent.shape[4]
context_latents = torch.zeros(size=(1, latent_c, 1 + 2 + 16, latent_h, latent_w), dtype=torch.float32).cpu()
accumulated_generated_latents = None
history_pixels = None
latent_paddings = list(reversed(range(total_latent_sections)))
if total_latent_sections > 4:
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
print(f" Using adjusted padding sequence for >4 sections: {latent_paddings}")
else:
print(f" Using standard padding sequence: {latent_paddings}")
# --- [MODIFIED] Restore Initial Context Initialization ---
if end_frame_latent is not None:
print(" Initializing context buffer's first slot with end frame latent.")
context_latents[:, :, 0:1, :, :] = end_frame_latent.cpu().float() # Ensure float32 CPU
else:
print(" No end frame latent available. Initial context remains zeros.")
# --- End Modified Context Initialization ---
# --- Main Sampling Loop (Generates Backward: End -> Start) ---
start_time = time.time()
num_loops = len(latent_paddings)
for i_loop, latent_padding in enumerate(latent_paddings):
section_start_time = time.time()
current_section_index_from_end = latent_padding
is_first_generation_step = (i_loop == 0)
is_last_generation_step = (latent_padding == 0)
print(f"\n--- Starting Generation Step {i_loop+1}/{num_loops} (Section Index from End: {current_section_index_from_end}, First Step: {is_first_generation_step}, Last Step: {is_last_generation_step}) ---")
latent_padding_size = latent_padding * latent_window_size
print(f' Padding size (latent frames): {latent_padding_size}, Window size (latent frames): {latent_window_size}')
# --- Select Conditioning Inputs for this Section ---
# 1. Conditioning Latent (`clean_latents_pre`) - Calculate Blend
# Determine the base latent (start or section-specific)
base_conditioning_latent = start_latent # Default to start (float32 CPU)
if current_section_index_from_end in section_latents:
base_conditioning_latent = section_latents[current_section_index_from_end] # Use section if available (float32 CPU)
print(f" Using SECTION {current_section_index_from_end} latent as base conditioning latent.")
else:
print(f" Using START frame latent as base conditioning latent.")
# Apply 'bookend' override to the base latent for the first step only
if end_frame_influence == "bookend" and is_first_generation_step and end_frame_latent is not None:
if current_section_index_from_end not in section_latents:
base_conditioning_latent = end_frame_latent # float32 CPU
print(" Applying 'bookend': Overriding base conditioning latent with END frame latent for first step.")
# Blend the base conditioning latent with the end frame latent based on mode/weight
current_conditioning_latent = base_conditioning_latent # Initialize with base
current_end_frame_latent_weight = 0.0
if end_frame_latent is not None: # Only blend if end frame exists
if end_frame_influence == 'progressive':
progress = i_loop / max(1, num_loops - 1)
current_end_frame_latent_weight = args.end_frame_weight * (1.0 - progress)
elif end_frame_influence == 'half':
if i_loop < num_loops / 2:
current_end_frame_latent_weight = args.end_frame_weight
# For 'last' and 'bookend', weight remains 0, no blending needed
current_end_frame_latent_weight = max(0.0, min(1.0, current_end_frame_latent_weight))
if current_end_frame_latent_weight > 1e-4: # Mix only if weight is significant
print(f" Blending Conditioning Latent: Base<-{1.0-current_end_frame_latent_weight:.3f} | End->{current_end_frame_latent_weight:.3f} (Mode: {end_frame_influence})")
# Ensure both inputs to mix_latents are float32 CPU
current_conditioning_latent = mix_latents(base_conditioning_latent.cpu().float(),
end_frame_latent.cpu().float(),
current_end_frame_latent_weight)
#else:
# print(f" Using BASE conditioning latent (Mode: {end_frame_influence}, Blend Weight near zero).") # Can be verbose
#else:
# print(f" Using BASE conditioning latent (No end frame specified for blending).") # Can be verbose
# 2. Image Embedding - Use Fixed Start Embedding
current_image_embedding = start_image_embedding # transformer.dtype CPU
print(f" Using fixed START frame image embedding.")
# 3. Text Embedding (Select section or global)
if current_section_index_from_end in section_prompt_embeddings:
current_llama_vec, current_clip_pooler = section_prompt_embeddings[current_section_index_from_end]
print(f" Using SECTION {current_section_index_from_end} prompt embeddings.")
else:
current_llama_vec = global_llama_vec
current_clip_pooler = global_clip_l_pooler
print(f" Using GLOBAL prompt embeddings.")
current_llama_vec_n = global_llama_vec_n
current_clip_pooler_n = global_clip_l_pooler_n
current_llama_attention_mask = global_llama_attention_mask
current_llama_attention_mask_n = global_llama_attention_mask_n
# --- Prepare Sampler Inputs ---
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = \
indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
# Prepare conditioning latents (float32 CPU)
clean_latents_pre = current_conditioning_latent # Use the potentially blended one
clean_latents_post, clean_latents_2x, clean_latents_4x = \
context_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
print(f" Final Conditioning shapes (CPU): clean={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}")
print(f" Clean Latents Pre stats - Min: {clean_latents_pre.min().item():.4f}, Max: {clean_latents_pre.max().item():.4f}, Mean: {clean_latents_pre.mean().item():.4f}")
# Load Transformer (Low VRAM)
if not high_vram:
print(" Moving Transformer to GPU...")
unload_complete_models()
move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
fake_diffusers_current_device(text_encoder, device)
# Configure TeaCache
if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
print(" TeaCache enabled.")
else:
transformer.initialize_teacache(enable_teacache=False)
print(" TeaCache disabled.")
# --- Run Sampling ---
print(f" Starting sampling ({steps} steps) for {num_frames} latent frames...")
sampling_step_start_time = time.time()
pbar = tqdm(total=steps, desc=f" Section {current_section_index_from_end} Sampling", leave=False)
def callback(d):
pbar.update(1)
return
current_sampler_device = transformer.device
current_text_encoder_device = text_encoder.device if not high_vram else device
# Move tensors to device just before sampling
_prompt_embeds = current_llama_vec.to(current_text_encoder_device)
_prompt_embeds_mask = current_llama_attention_mask.to(current_text_encoder_device)
_prompt_poolers = current_clip_pooler.to(current_sampler_device)
_negative_prompt_embeds = current_llama_vec_n.to(current_text_encoder_device)
_negative_prompt_embeds_mask = current_llama_attention_mask_n.to(current_text_encoder_device)
_negative_prompt_poolers = current_clip_pooler_n.to(current_sampler_device)
_image_embeddings = current_image_embedding.to(current_sampler_device) # Fixed start embedding
_latent_indices = latent_indices.to(current_sampler_device)
# Pass conditioning latents (now potentially blended) to sampler
_clean_latents = clean_latents.to(current_sampler_device, dtype=transformer.dtype)
_clean_latent_indices = clean_latent_indices.to(current_sampler_device)
_clean_latents_2x = clean_latents_2x.to(current_sampler_device, dtype=transformer.dtype)
_clean_latent_2x_indices = clean_latent_2x_indices.to(current_sampler_device)
_clean_latents_4x = clean_latents_4x.to(current_sampler_device, dtype=transformer.dtype)
_clean_latent_4x_indices = clean_latent_4x_indices.to(current_sampler_device)
generated_latents_gpu = sample_hunyuan(
transformer=transformer,
sampler='unipc',
width=width,
height=height,
frames=num_frames,
real_guidance_scale=cfg,
distilled_guidance_scale=gs,
guidance_rescale=rs,
num_inference_steps=steps,
generator=rnd,
prompt_embeds=_prompt_embeds,
prompt_embeds_mask=_prompt_embeds_mask,
prompt_poolers=_prompt_poolers,
negative_prompt_embeds=_negative_prompt_embeds,
negative_prompt_embeds_mask=_negative_prompt_embeds_mask,
negative_prompt_poolers=_negative_prompt_poolers,
device=current_sampler_device,
dtype=transformer.dtype,
image_embeddings=_image_embeddings, # Using fixed start embedding
latent_indices=_latent_indices,
clean_latents=_clean_latents, # Using potentially blended latents
clean_latent_indices=_clean_latent_indices,
clean_latents_2x=_clean_latents_2x,
clean_latent_2x_indices=_clean_latent_2x_indices,
clean_latents_4x=_clean_latents_4x,
clean_latent_4x_indices=_clean_latent_4x_indices,
callback=callback,
)
pbar.close()
sampling_step_end_time = time.time()
print(f" Sampling finished in {sampling_step_end_time - sampling_step_start_time:.2f} seconds.")
print(f" Raw generated latent shape for this step: {generated_latents_gpu.shape}")
print(f" Generated latents stats (GPU) - Min: {generated_latents_gpu.min().item():.4f}, Max: {generated_latents_gpu.max().item():.4f}, Mean: {generated_latents_gpu.mean().item():.4f}")
# Move generated latents to CPU as float32
generated_latents_cpu = generated_latents_gpu.cpu().float()
del generated_latents_gpu, _prompt_embeds, _prompt_embeds_mask, _prompt_poolers, _negative_prompt_embeds, _negative_prompt_embeds_mask, _negative_prompt_poolers
del _image_embeddings, _latent_indices, _clean_latents, _clean_latent_indices, _clean_latents_2x, _clean_latent_2x_indices, _clean_latents_4x, _clean_latent_4x_indices
if device.type == 'cuda': torch.cuda.empty_cache()
# Offload Transformer and TE1 (Low VRAM)
if not high_vram:
print(" Low VRAM mode: Offloading Transformer and Text Encoder from GPU...")
offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
offload_model_from_device_for_memory_preservation(text_encoder, target_device=device, preserved_memory_gb=gpu_memory_preservation)
print(" Transformer and Text Encoder offloaded.")
# --- History/Context Update ---
if is_last_generation_step:
print(" Last generation step: Prepending start frame latent to generated latents.")
generated_latents_cpu = torch.cat([start_latent.cpu().float(), generated_latents_cpu], dim=2)
print(f" Shape after prepending start latent: {generated_latents_cpu.shape}")
context_latents = torch.cat([generated_latents_cpu, context_latents], dim=2)
print(f" Context buffer updated. New shape: {context_latents.shape}")
# Accumulate the generated latents for the final video output
if accumulated_generated_latents is None:
accumulated_generated_latents = generated_latents_cpu
else:
accumulated_generated_latents = torch.cat([generated_latents_cpu, accumulated_generated_latents], dim=2)
current_total_latent_frames = accumulated_generated_latents.shape[2]
print(f" Accumulated generated latents updated. Total latent frames: {current_total_latent_frames}")
print(f" Accumulated latents stats - Min: {accumulated_generated_latents.min().item():.4f}, Max: {accumulated_generated_latents.max().item():.4f}, Mean: {accumulated_generated_latents.mean().item():.4f}")
# --- VAE Decoding & Merging ---
print(" Decoding generated latents and merging video...")
decode_start_time = time.time()
if not high_vram:
print(" Moving VAE to GPU...")
offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
unload_complete_models(text_encoder, text_encoder_2, image_encoder)
load_model_as_complete(vae, target_device=device)
print(" VAE loaded.")
print(f" Decoding current section's latents (shape: {generated_latents_cpu.shape}) for append.")
latents_to_decode_for_append = generated_latents_cpu.to(device=device, dtype=vae.dtype)
current_pixels = vae_decode(latents_to_decode_for_append, vae).cpu().float() # Decode and move to CPU float32
print(f" Decoded pixels for append shape: {current_pixels.shape}")
del latents_to_decode_for_append
if device.type == 'cuda': torch.cuda.empty_cache()
if history_pixels is None:
history_pixels = current_pixels
print(f" Initialized history_pixels shape: {history_pixels.shape}")
else:
append_overlap = 3
print(f" Appending section with pixel overlap: {append_overlap}")
history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlap=append_overlap)
print(f" Appended. New total pixel shape: {history_pixels.shape}")
if not high_vram:
print(" Low VRAM mode: Unloading VAE from GPU...")
unload_complete_models(vae)
print(" VAE unloaded.")
decode_end_time = time.time()
print(f" Decoding and merging finished in {decode_end_time - decode_start_time:.2f} seconds.")
# --- Save Intermediate/Section Output ---
current_num_pixel_frames = history_pixels.shape[2]
if save_section_frames:
try:
first_frame_index = 0 # Index 0 of the newly decoded chunk is the first frame generated in this step
frame_to_save = current_pixels[0, :, first_frame_index, :, :]
frame_to_save = einops.rearrange(frame_to_save, 'c h w -> h w c')
frame_to_save_np = frame_to_save.cpu().numpy()
frame_to_save_np = np.clip((frame_to_save_np * 127.5 + 127.5), 0, 255).astype(np.uint8)
section_frame_filename = output_dir / f'{job_id}_section_start_frame_idx{current_section_index_from_end}.png' # Renamed for clarity
Image.fromarray(frame_to_save_np).save(section_frame_filename)
print(f" Saved first generated pixel frame of section {current_section_index_from_end} (from decoded chunk) to: {section_frame_filename}")
except Exception as e:
print(f" [WARN] Error saving section {current_section_index_from_end} start frame image: {e}")
if save_intermediate or is_last_generation_step:
output_filename = output_dir / f'{job_id}_step{i_loop+1}_idx{current_section_index_from_end}_frames{current_num_pixel_frames}_{width}x{height}.mp4'
print(f" Saving {'intermediate' if not is_last_generation_step else 'final'} video ({current_num_pixel_frames} frames) to: {output_filename}")
try:
save_bcthw_as_mp4(history_pixels.float(), str(output_filename), fps=int(fps))
print(f" Saved video using save_bcthw_as_mp4")
if not is_last_generation_step:
print(f"INTERMEDIATE_VIDEO_PATH:{output_filename}")
final_video_path = str(output_filename)
except Exception as e:
print(f" Error saving video using save_bcthw_as_mp4: {e}")
traceback.print_exc()
# Fallback save attempt
try:
first_frame_img = history_pixels.float()[0, :, 0].permute(1, 2, 0).cpu().numpy()
first_frame_img = (first_frame_img * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
frame_path = str(output_filename).replace('.mp4', '_first_frame_ERROR.png')
Image.fromarray(first_frame_img).save(frame_path)
print(f" Saved first frame as image to {frame_path} due to video saving error.")
except Exception as frame_err:
print(f" Could not save first frame either: {frame_err}")
section_end_time = time.time()
print(f"--- Generation Step {i_loop+1} finished in {section_end_time - section_start_time:.2f} seconds ---")
if is_last_generation_step:
print("\nFinal generation step completed.")
break
# --- Final Video Saved During Last Step ---
if final_video_path and os.path.exists(final_video_path):
print(f"\nSuccessfully generated: {final_video_path}")
print(f"ACTUAL_FINAL_PATH:{final_video_path}")
return final_video_path
else:
print("\nError: Final video path not found or not saved correctly.")
return None
except Exception as e:
print("\n--- ERROR DURING GENERATION ---")
traceback.print_exc()
print("-----------------------------")
if 'history_pixels' in locals() and history_pixels is not None and history_pixels.shape[2] > 0:
partial_output_name = output_dir / f"{job_id}_partial_ERROR_{history_pixels.shape[2]}_frames_{width}x{height}.mp4"
print(f"Attempting to save partial video to: {partial_output_name}")
try:
save_bcthw_as_mp4(history_pixels.float(), str(partial_output_name), fps=fps)
print(f"ACTUAL_FINAL_PATH:{partial_output_name}")
return str(partial_output_name)
except Exception as save_err:
print(f"Error saving partial video during error handling: {save_err}")
traceback.print_exc()
print("Status: Error occurred, no video saved.")
return None
finally:
print("Performing final model cleanup...")
try:
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
except Exception as e:
print(f"Error during final model unload: {e}")
pass
if device.type == 'cuda':
torch.cuda.empty_cache()
print("CUDA cache cleared.")
def main():
    """Command-line entry point for the script.

    Parses the CLI arguments, loads the models, and runs ``generate_video``.
    The process exits with status 0 when ``generate_video`` returns a
    truthy path, and with status 1 otherwise.
    """
    cli_args = parse_args()
    loaded_models = load_models(cli_args)
    result_path = generate_video(cli_args, loaded_models)

    # Failure path first: generate_video signals failure with a falsy return (None).
    if not result_path:
        print("\nVideo generation failed.")
        sys.exit(1)

    print(f"\nVideo generation finished. Final path: {result_path}")
    sys.exit(0)


if __name__ == "__main__":
    main()