#!/usr/bin/env python3
# Combined and Corrected Script

import argparse
import os
import sys
import time
import random
import traceback
from datetime import datetime
from pathlib import Path
import re  # For parsing --section arguments

import einops
import numpy as np
import torch
import av  # For saving video (used by save_bcthw_as_mp4)
from PIL import Image
from tqdm import tqdm
import cv2
# --- Dependencies from diffusers_helper ---
# Ensure this library is installed or in the PYTHONPATH.
try:
    # from diffusers_helper.hf_login import login  # Not strictly needed for inference if models are public/cached
    from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode  # vae_decode_fake is not used here
    from diffusers_helper.utils import (save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw,
                                        resize_and_center_crop, generate_timestamp)
    from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
    from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
    from diffusers_helper.memory import (cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation,
                                         offload_model_from_device_for_memory_preservation, fake_diffusers_current_device,
                                         DynamicSwapInstaller, unload_complete_models, load_model_as_complete)
    from diffusers_helper.clip_vision import hf_clip_vision_encode
    from diffusers_helper.bucket_tools import find_nearest_bucket  # bucket_options is no longer needed here
except ImportError:
    print("Error: Could not import modules from 'diffusers_helper'.")
    print("Please ensure the 'diffusers_helper' library is installed and accessible.")
    print("You might need to clone the repository and add it to your PYTHONPATH.")
    sys.exit(1)
# --- End Dependencies ---

from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
from transformers import SiglipImageProcessor, SiglipVisionModel
# --- Constants ---
DIMENSION_MULTIPLE = 16  # VAE and model constraints often require divisibility by 8 or 16; 16 is safer.
SECTION_ARG_PATTERN = re.compile(r"^(\d+):([^:]+)(?::(.*))?$")  # Section arg format: number:image_path[:prompt]
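# Illustrative strings matched by SECTION_ARG_PATTERN (paths are hypothetical):
#   "0:keyframes/start.png"              -> index 0, no per-section prompt
#   "2:keyframes/mid.png:A quiet forest" -> index 2 with its own prompt
# Note: because the image-path group is [^:]+, paths containing ':' (e.g. Windows drive letters)
# will not match this pattern.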

def parse_section_args(section_strings):
    """Parses the --section arguments into a dictionary."""
    section_data = {}
    if not section_strings:
        return section_data
    for section_str in section_strings:
        match = SECTION_ARG_PATTERN.match(section_str)
        if not match:
            print(f"Warning: Invalid section format: '{section_str}'. Expected 'number:image_path[:prompt]'. Skipping.")
            continue
        section_index_str, image_path, prompt_text = match.groups()
        section_index = int(section_index_str)
        prompt_text = prompt_text if prompt_text else None
        if not os.path.exists(image_path):
            print(f"Warning: Image path for section {section_index} ('{image_path}') not found. Skipping section.")
            continue
        if section_index in section_data:
            print(f"Warning: Duplicate section index {section_index}. Overwriting previous entry.")
        section_data[section_index] = (image_path, prompt_text)
        print(f"Parsed section {section_index}: Image='{image_path}', Prompt='{prompt_text}'")
    return section_data

def parse_args():
    parser = argparse.ArgumentParser(description="FramePack HunyuanVideo inference script (CLI version with advanced end-frame & section control)")
    # --- Model Paths ---
    parser.add_argument('--transformer_path', type=str, default='lllyasviel/FramePackI2V_HY', help="Path to the FramePack Transformer model")
    parser.add_argument('--vae_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the VAE model directory")
    parser.add_argument('--text_encoder_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the Llama text encoder directory")
    parser.add_argument('--text_encoder_2_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the CLIP text encoder directory")
    parser.add_argument('--image_encoder_path', type=str, default='lllyasviel/flux_redux_bfl', help="Path to the SigLIP image encoder directory")
    parser.add_argument('--hf_home', type=str, default='./hf_download', help="Directory to download/cache Hugging Face models")
    # --- Input ---
    parser.add_argument("--input_image", type=str, required=True, help="Path to the input image (start frame)")
    parser.add_argument("--end_frame", type=str, default=None, help="Path to the optional end frame image (video end)")
    parser.add_argument("--prompt", type=str, required=True, help="Default prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for generation")
    # <<< START: Modified arguments for end frame >>>
    parser.add_argument("--end_frame_weight", type=float, default=0.3,
                        help="End frame influence weight (0.0-1.0) for the blending modes ('half', 'progressive'). Higher values blend more of the end frame's conditioning latent.")
    parser.add_argument("--end_frame_influence", type=str, default="last",
                        choices=["last", "half", "progressive", "bookend"],
                        help="How to use the global end frame: 'last' uses the end frame for initial context only (no latent blending); "
                             "'half' blends start/end conditioning latents for the second half of the video; "
                             "'progressive' gradually blends conditioning latents from end to start; "
                             "'bookend' uses the end frame conditioning latent ONLY for the first generated section if no section keyframe is set (no blending otherwise). "
                             "All modes use the start image embedding.")
    # <<< END: Modified arguments for end frame >>>
    # <<< START: New arguments for section control >>>
    parser.add_argument("--section", type=str, action='append',
                        help="Define a keyframe section. Format: 'index:image_path[:prompt]'. Index 0 is the last generated section (video start), 1 is the second to last, etc. "
                             "Repeat for multiple sections. Example: --section 0:path/to/start_like.png:'A sunrise' --section 2:path/to/mid.png")
    # <<< END: New arguments for section control >>>
    # --- Output Resolution (choose ONE method) ---
    parser.add_argument("--target_resolution", type=int, default=None, help=f"Target resolution for the longer side, used for automatic aspect-ratio bucketing. Used if --width and --height are not specified. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--width", type=int, default=None, help=f"Explicit target width for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--height", type=int, default=None, help=f"Explicit target height for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    # --- Output ---
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save the generated video")
    parser.add_argument("--save_intermediate_sections", action='store_true', help="Save the video after each section is generated and decoded.")
    parser.add_argument("--save_section_final_frames", action='store_true', help="Save the final decoded frame of each generated section as a PNG image.")
    # --- Generation Parameters (matching the Gradio demo defaults where applicable) ---
    parser.add_argument("--seed", type=int, default=None, help="Seed for generation. Random if not set.")
    parser.add_argument("--total_second_length", type=float, default=5.0, help="Total desired video length in seconds")
    parser.add_argument("--fps", type=int, default=30, help="Frames per second for the output video")
    parser.add_argument("--steps", type=int, default=25, help="Number of inference steps (changing this is not recommended)")
    parser.add_argument("--distilled_guidance_scale", "--gs", type=float, default=10.0, help="Distilled CFG scale (gs)")
    parser.add_argument("--cfg", type=float, default=1.0, help="Classifier-free guidance scale (usually fixed at 1.0 for FramePack)")
    parser.add_argument("--rs", type=float, default=0.0, help="CFG rescale (usually fixed at 0.0 for FramePack)")
    parser.add_argument("--latent_window_size", type=int, default=9, help="Latent window size (changing this is not recommended)")
    # --- Performance / Memory ---
    parser.add_argument('--high_vram', action='store_true', help="Force high VRAM mode (loads all models to GPU)")
    parser.add_argument('--low_vram', action='store_true', help="Force low VRAM mode (uses dynamic swapping)")
    parser.add_argument("--gpu_memory_preservation", type=float, default=6.0, help="GPU memory (GB) to preserve when offloading (low VRAM mode)")
    parser.add_argument('--use_teacache', action='store_true', default=True, help="Use TeaCache optimization (default: True)")
    parser.add_argument('--no_teacache', action='store_false', dest='use_teacache', help="Disable TeaCache optimization")
    parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda', 'cpu'). Auto-detects if None.")
    args = parser.parse_args()

    # --- Argument Validation ---
    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
        print(f"Generated random seed: {args.seed}")
    if args.width is not None and args.height is not None:
        if args.width <= 0 or args.height <= 0:
            print(f"Error: Explicit --width ({args.width}) and --height ({args.height}) must be positive.")
            sys.exit(1)
        if args.target_resolution is not None:
            print("Warning: Both --width/--height and --target_resolution specified. Using explicit --width and --height.")
            args.target_resolution = None
    elif args.target_resolution is not None:
        if args.target_resolution <= 0:
            print(f"Error: --target_resolution ({args.target_resolution}) must be positive.")
            sys.exit(1)
        if args.width is not None or args.height is not None:
            print("Error: Cannot specify --target_resolution with only one of --width or --height. Provide both or neither.")
            sys.exit(1)
    else:
        print("Warning: No resolution specified. Defaulting to --target_resolution 640.")
        args.target_resolution = 640
    if args.end_frame_weight < 0.0 or args.end_frame_weight > 1.0:
        print(f"Error: --end_frame_weight must be between 0.0 and 1.0 (got {args.end_frame_weight}).")
        sys.exit(1)
    if args.width is not None and args.width % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --width ({args.width}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.height is not None and args.height % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --height ({args.height}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.target_resolution is not None and args.target_resolution % DIMENSION_MULTIPLE != 0:
        print(f"Warning: Specified --target_resolution ({args.target_resolution}) is not divisible by {DIMENSION_MULTIPLE}. The calculated dimensions will be rounded down.")
    if args.end_frame and not os.path.exists(args.end_frame):
        print(f"Error: End frame image not found at '{args.end_frame}'.")
        sys.exit(1)
    args.section_data = parse_section_args(args.section)
    os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(args.hf_home))
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)
    return args
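# Example invocation (illustrative; the script and image file names are hypothetical):
#   python framepack_cli.py --input_image start.png --end_frame end.png \
#       --prompt "A cat walking through a garden" --end_frame_influence half --end_frame_weight 0.3 \
#       --section 0:keyframes/opening.png:"A sunrise" --target_resolution 640 \
#       --total_second_length 5 --save_path ./outputs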

def load_models(args):
    """Loads all necessary models."""
    print("Loading models...")
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device(gpu if torch.cuda.is_available() else cpu)
    print(f"Using device: {device}")

    print(" Loading Text Encoder 1 (Llama)...")
    text_encoder = LlamaModel.from_pretrained(args.text_encoder_path, subfolder='text_encoder', torch_dtype=torch.float16).cpu()
    print(" Loading Text Encoder 2 (CLIP)...")
    text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
    print(" Loading Tokenizer 1 (Llama)...")
    tokenizer = LlamaTokenizerFast.from_pretrained(args.text_encoder_path, subfolder='tokenizer')
    print(" Loading Tokenizer 2 (CLIP)...")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path, subfolder='tokenizer_2')
    print(" Loading VAE...")
    vae = AutoencoderKLHunyuanVideo.from_pretrained(args.vae_path, subfolder='vae', torch_dtype=torch.float16).cpu()
    print(" Loading Image Feature Extractor (SigLIP)...")
    feature_extractor = SiglipImageProcessor.from_pretrained(args.image_encoder_path, subfolder='feature_extractor')
    print(" Loading Image Encoder (SigLIP)...")
    image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder_path, subfolder='image_encoder', torch_dtype=torch.float16).cpu()
    print(" Loading Transformer (FramePack)...")
    transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(args.transformer_path, torch_dtype=torch.bfloat16).cpu()

    vae.eval()
    text_encoder.eval()
    text_encoder_2.eval()
    image_encoder.eval()
    transformer.eval()

    transformer.high_quality_fp32_output_for_inference = True
    print('transformer.high_quality_fp32_output_for_inference = True')

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    image_encoder.requires_grad_(False)
    transformer.requires_grad_(False)

    print("Models loaded.")
    return {
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "vae": vae,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "transformer": transformer,
        "device": device,
    }

def adjust_to_multiple(value, multiple):
    """Rounds down value to the nearest multiple."""
    return (value // multiple) * multiple
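# e.g. adjust_to_multiple(638, 16) == 624, while adjust_to_multiple(640, 16) == 640.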

def mix_latents(latent_a, latent_b, weight_b):
    """Mixes two latents with the specified weight for latent_b."""
    if latent_a is None: return latent_b
    if latent_b is None: return latent_a
    target_device = latent_a.device
    target_dtype = latent_a.dtype
    if latent_b.device != target_device:
        latent_b = latent_b.to(target_device)
    if latent_b.dtype != target_dtype:
        latent_b = latent_b.to(dtype=target_dtype)
    if isinstance(weight_b, torch.Tensor):
        weight_b = weight_b.item()
    weight_b = max(0.0, min(1.0, weight_b))
    if weight_b == 0.0:
        return latent_a
    elif weight_b == 1.0:
        return latent_b
    else:
        return (1.0 - weight_b) * latent_a + weight_b * latent_b


def mix_embeddings(embed_a, embed_b, weight_b):
    """Mixes two embedding tensors (e.g. CLIP image embeddings) with the specified weight for embed_b."""
    if embed_a is None: return embed_b
    if embed_b is None: return embed_a
    target_device = embed_a.device
    target_dtype = embed_a.dtype
    if embed_b.device != target_device:
        embed_b = embed_b.to(target_device)
    if embed_b.dtype != target_dtype:
        embed_b = embed_b.to(dtype=target_dtype)
    if isinstance(weight_b, torch.Tensor):
        weight_b = weight_b.item()
    weight_b = max(0.0, min(1.0, weight_b))
    if weight_b == 0.0:
        return embed_a
    elif weight_b == 1.0:
        return embed_b
    else:
        return (1.0 - weight_b) * embed_a + weight_b * embed_b
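# Both helpers are a clamped linear interpolation, (1 - w) * a + w * b. mix_latents is used below to
# blend the start/section conditioning latent with the end-frame latent in the 'half' and 'progressive'
# modes; mix_embeddings appears unused in the current flow, since the sampler always receives the fixed
# start-frame image embedding.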

def preprocess_image_for_generation(image_path, target_width, target_height, job_id, output_dir, frame_name="input"):
    """Loads, processes, and saves a single image."""
    try:
        image = Image.open(image_path).convert('RGB')
        image_np = np.array(image)
    except Exception as e:
        print(f"Error loading image '{image_path}': {e}")
        raise
    H_orig, W_orig, _ = image_np.shape
    print(f" {frame_name.capitalize()} image loaded ({W_orig}x{H_orig}): '{image_path}'")
    image_resized_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)
    try:
        Image.fromarray(image_resized_np).save(output_dir / f'{job_id}_{frame_name}_resized_{target_width}x{target_height}.png')
    except Exception as e:
        print(f"Warning: Could not save resized image preview for {frame_name}: {e}")
    image_pt = torch.from_numpy(image_resized_np).float() / 127.5 - 1.0
    image_pt = image_pt.permute(2, 0, 1)[None, :, None]  # Shape: (B=1, C=3, T=1, H, W)
    print(f" {frame_name.capitalize()} image processed to tensor shape: {image_pt.shape}")
    return image_np, image_resized_np, image_pt
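# Returns (original_np, resized_np, tensor): the tensor is normalized to [-1, 1] for vae_encode,
# while the resized uint8 array is what later goes through the SigLIP feature extractor.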

def generate_video(args, models):
    """Generates the video using the loaded models and arguments."""
    # Unpack models
    text_encoder = models["text_encoder"]
    text_encoder_2 = models["text_encoder_2"]
    tokenizer = models["tokenizer"]
    tokenizer_2 = models["tokenizer_2"]
    vae = models["vae"]
    feature_extractor = models["feature_extractor"]
    image_encoder = models["image_encoder"]
    transformer = models["transformer"]
    device = models["device"]

    # --- Determine Memory Mode ---
    if args.high_vram and args.low_vram:
        print("Warning: Both --high_vram and --low_vram specified. Defaulting to auto-detection.")
        force_high_vram = force_low_vram = False
    else:
        force_high_vram = args.high_vram
        force_low_vram = args.low_vram
    if force_high_vram:
        high_vram = True
    elif force_low_vram:
        high_vram = False
    else:
        free_mem_gb = get_cuda_free_memory_gb(device) if device.type == 'cuda' else 0
        high_vram = free_mem_gb > 60
        print(f'Auto-detected Free VRAM {free_mem_gb:.2f} GB -> High-VRAM Mode: {high_vram}')

    # --- Configure Models based on VRAM mode ---
    if not high_vram:
        print("Configuring for Low VRAM mode...")
        vae.enable_slicing()
        vae.enable_tiling()
        print(" Installing DynamicSwap for Transformer...")
        DynamicSwapInstaller.install_model(transformer, device=device)
        print(" Installing DynamicSwap for Text Encoder 1...")
        DynamicSwapInstaller.install_model(text_encoder, device=device)
        print("Unloading models from GPU (Low VRAM setup)...")
        unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
    else:
        print("Configuring for High VRAM mode (moving models to GPU)...")
        text_encoder.to(device)
        text_encoder_2.to(device)
        image_encoder.to(device)
        vae.to(device)
        transformer.to(device)
        print(" Models moved to GPU.")

    # --- Prepare Inputs ---
    print("Preparing inputs...")
    prompt = args.prompt
    n_prompt = args.negative_prompt
    seed = args.seed
    total_second_length = args.total_second_length
    latent_window_size = args.latent_window_size
    steps = args.steps
    cfg = args.cfg
    gs = args.distilled_guidance_scale
    rs = args.rs
    gpu_memory_preservation = args.gpu_memory_preservation
    use_teacache = args.use_teacache
    fps = args.fps
    end_frame_path = args.end_frame
    end_frame_influence = args.end_frame_influence
    end_frame_weight = args.end_frame_weight
    section_data = args.section_data
    save_intermediate = args.save_intermediate_sections
    save_section_frames = args.save_section_final_frames

    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
    print(f"Calculated total latent sections: {total_latent_sections}")
    job_id = generate_timestamp() + f"_seed{seed}"
    output_dir = Path(args.save_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_video_path = None

    # --- Section Preprocessing Storage ---
    section_latents = {}
    section_image_embeddings = {}  # Still stored; might be useful later
    section_prompt_embeddings = {}
    try:
        # --- Text Encoding (Global Prompts) ---
        print("Encoding global text prompts...")
        if not high_vram:
            print(" Low VRAM mode: Loading Text Encoders to GPU...")
            fake_diffusers_current_device(text_encoder, device)
            load_model_as_complete(text_encoder_2, target_device=device)
            print(" Text Encoders loaded.")
        global_llama_vec, global_clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        if cfg == 1.0:
            print(" CFG scale is 1.0, using zero negative embeddings.")
            global_llama_vec_n, global_clip_l_pooler_n = torch.zeros_like(global_llama_vec), torch.zeros_like(global_clip_l_pooler)
        else:
            print(f" Encoding negative prompt: '{n_prompt}'")
            global_llama_vec_n, global_clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        global_llama_vec, global_llama_attention_mask = crop_or_pad_yield_mask(global_llama_vec, length=512)
        global_llama_vec_n, global_llama_attention_mask_n = crop_or_pad_yield_mask(global_llama_vec_n, length=512)
        print(" Global text encoded and processed.")

        # --- Section Text Encoding ---
        if section_data:
            print("Encoding section-specific prompts...")
            for section_index, (img_path, prompt_text) in section_data.items():
                if prompt_text:
                    print(f" Encoding prompt for section {section_index}: '{prompt_text}'")
                    sec_llama_vec, sec_clip_pooler = encode_prompt_conds(prompt_text, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
                    sec_llama_vec, _ = crop_or_pad_yield_mask(sec_llama_vec, length=512)
                    section_prompt_embeddings[section_index] = (
                        sec_llama_vec.cpu().to(transformer.dtype),
                        sec_clip_pooler.cpu().to(transformer.dtype),
                    )
                    print(f" Section {section_index} prompt encoded and stored on CPU.")
                else:
                    print(f" Section {section_index} has no specific prompt; the global prompt will be used.")
        if not high_vram:
            print(" Low VRAM mode: Unloading Text Encoders from GPU...")
            unload_complete_models(text_encoder_2)
            print(" Text Encoder 2 unloaded.")
        # --- Input Image Processing & Dimension Calculation ---
        print("Processing input image and determining dimensions...")
        try:
            # Loaded at 1x1 only to obtain the original array for aspect-ratio calculation;
            # the resized outputs of this call are discarded.
            input_image_np_orig, _, _ = preprocess_image_for_generation(
                args.input_image, 1, 1, job_id, output_dir, "temp_input_orig"
            )
        except Exception as e:
            print(f"Error loading input image '{args.input_image}' for dimension check: {e}")
            raise
        H_orig, W_orig, _ = input_image_np_orig.shape
        print(f" Input image original size: {W_orig}x{H_orig}")

        if args.width is not None and args.height is not None:
            target_w, target_h = args.width, args.height
            print(f" Using explicit target dimensions: {target_w}x{target_h}")
        elif args.target_resolution is not None:
            print(f" Calculating dimensions based on target resolution for longer side: {args.target_resolution}")
            target_h, target_w = find_nearest_bucket(H_orig, W_orig, resolution=args.target_resolution)
            print(f" Calculated dimensions (before adjustment): {target_w}x{target_h}")
        else:
            raise ValueError("Internal Error: Resolution determination failed.")

        final_w = adjust_to_multiple(target_w, DIMENSION_MULTIPLE)
        final_h = adjust_to_multiple(target_h, DIMENSION_MULTIPLE)
        if final_w <= 0 or final_h <= 0:
            print(f"Error: Calculated dimensions ({target_w}x{target_h}) resulted in non-positive dimensions after adjusting to be divisible by {DIMENSION_MULTIPLE} ({final_w}x{final_h}).")
            raise ValueError("Adjusted dimensions are invalid.")
        if final_w != target_w or final_h != target_h:
            print(f"Warning: Adjusted dimensions from {target_w}x{target_h} to {final_w}x{final_h} to be divisible by {DIMENSION_MULTIPLE}.")
        else:
            print(f" Final dimensions confirmed: {final_w}x{final_h}")
        width, height = final_w, final_h
        if width * height > 1024 * 1024:
            print(f"Warning: Target resolution {width}x{height} is large. Ensure you have sufficient VRAM.")

        _, input_image_resized_np, input_image_pt = preprocess_image_for_generation(
            args.input_image, width, height, job_id, output_dir, "input"
        )
        end_frame_resized_np = None
        end_frame_pt = None
        if end_frame_path:
            _, end_frame_resized_np, end_frame_pt = preprocess_image_for_generation(
                end_frame_path, width, height, job_id, output_dir, "end"
            )
        section_images_resized_np = {}
        section_images_pt = {}
        if section_data:
            print("Processing section keyframe images...")
            for section_index, (img_path, _) in section_data.items():
                _, sec_resized_np, sec_pt = preprocess_image_for_generation(
                    img_path, width, height, job_id, output_dir, f"section{section_index}"
                )
                section_images_resized_np[section_index] = sec_resized_np
                section_images_pt[section_index] = sec_pt
        # --- VAE Encoding ---
        print("VAE encoding initial frame...")
        if not high_vram:
            print(" Low VRAM mode: Loading VAE to GPU...")
            load_model_as_complete(vae, target_device=device)
            print(" VAE loaded.")
        input_image_pt_dev = input_image_pt.to(device=device, dtype=vae.dtype)
        start_latent = vae_encode(input_image_pt_dev, vae)  # On GPU, vae.dtype
        print(f" Initial latent shape: {start_latent.shape}")
        print(f" Start latent stats - Min: {start_latent.min().item():.4f}, Max: {start_latent.max().item():.4f}, Mean: {start_latent.mean().item():.4f}")

        end_frame_latent = None
        if end_frame_pt is not None:
            print("VAE encoding end frame...")
            end_frame_pt_dev = end_frame_pt.to(device=device, dtype=vae.dtype)
            end_frame_latent = vae_encode(end_frame_pt_dev, vae)  # On GPU, vae.dtype
            print(f" End frame latent shape: {end_frame_latent.shape}")
            print(f" End frame latent stats - Min: {end_frame_latent.min().item():.4f}, Max: {end_frame_latent.max().item():.4f}, Mean: {end_frame_latent.mean().item():.4f}")
            if end_frame_latent.shape != start_latent.shape:
                print("Warning: End frame latent shape mismatch. Reshaping.")
                try:
                    end_frame_latent = end_frame_latent.reshape(start_latent.shape)
                except Exception as reshape_err:
                    print(f"Error reshaping end frame latent: {reshape_err}. Disabling end frame.")
                    end_frame_latent = None

        if section_images_pt:
            print("VAE encoding section keyframes...")
            for section_index, sec_pt in section_images_pt.items():
                sec_pt_dev = sec_pt.to(device=device, dtype=vae.dtype)
                sec_latent = vae_encode(sec_pt_dev, vae)  # On GPU, vae.dtype
                print(f" Section {section_index} latent shape: {sec_latent.shape}")
                if sec_latent.shape != start_latent.shape:
                    print(f" Warning: Section {section_index} latent shape mismatch. Reshaping.")
                    try:
                        sec_latent = sec_latent.reshape(start_latent.shape)
                    except Exception as reshape_err:
                        print(f" Error reshaping section {section_index} latent: {reshape_err}. Skipping section latent.")
                        continue
                # Store on CPU as float32 for context/blending later
                section_latents[section_index] = sec_latent.cpu().float()
                print(f" Section {section_index} latent encoded and stored on CPU.")

        if not high_vram:
            print(" Low VRAM mode: Unloading VAE from GPU...")
            unload_complete_models(vae)
            print(" VAE unloaded.")
        # Move essential latents to CPU as float32 for context/blending
        start_latent = start_latent.cpu().float()
        if end_frame_latent is not None:
            end_frame_latent = end_frame_latent.cpu().float()
        # --- CLIP Vision Encoding ---
        print("CLIP Vision encoding image(s)...")
        if not high_vram:
            print(" Low VRAM mode: Loading Image Encoder to GPU...")
            load_model_as_complete(image_encoder, target_device=device)
            print(" Image Encoder loaded.")
        # Encode the start frame; this embedding is used consistently as image_embeddings for every section.
        image_encoder_output = hf_clip_vision_encode(input_image_resized_np, feature_extractor, image_encoder)
        start_image_embedding = image_encoder_output.last_hidden_state  # On GPU, image_encoder.dtype
        print(f" Start image embedding shape: {start_image_embedding.shape}")
        # The end frame embedding is not needed for this strategy, so it is not encoded here.
        # Encode section keyframes (if provided) and store them for potential future use.
        if section_images_resized_np:
            print("CLIP Vision encoding section keyframes (storing on CPU)...")
            for section_index, sec_resized_np in section_images_resized_np.items():
                sec_output = hf_clip_vision_encode(sec_resized_np, feature_extractor, image_encoder)
                sec_embedding = sec_output.last_hidden_state
                section_image_embeddings[section_index] = sec_embedding.cpu().to(transformer.dtype)
                print(f" Section {section_index} embedding shape: {sec_embedding.shape}. Stored on CPU.")
        if not high_vram:
            print(" Low VRAM mode: Unloading Image Encoder from GPU...")
            unload_complete_models(image_encoder)
            print(" Image Encoder unloaded.")
        # Move start image embedding to CPU (transformer dtype)
        target_dtype = transformer.dtype
        start_image_embedding = start_image_embedding.cpu().to(target_dtype)

        # --- Prepare Global Embeddings for Transformer (CPU, transformer.dtype) ---
        print("Preparing global embeddings for Transformer...")
        global_llama_vec = global_llama_vec.cpu().to(target_dtype)
        global_llama_vec_n = global_llama_vec_n.cpu().to(target_dtype)
        global_clip_l_pooler = global_clip_l_pooler.cpu().to(target_dtype)
        global_clip_l_pooler_n = global_clip_l_pooler_n.cpu().to(target_dtype)
        print(f" Global Embeddings prepared on CPU with dtype {target_dtype}.")
        # --- Sampling Setup ---
        print("Setting up sampling...")
        rnd = torch.Generator(cpu).manual_seed(seed)
        num_frames = latent_window_size * 4 - 3
        print(f" Frames per sampling step (num_frames input): {num_frames}")
        latent_c, latent_h, latent_w = start_latent.shape[1], start_latent.shape[3], start_latent.shape[4]
        context_latents = torch.zeros(size=(1, latent_c, 1 + 2 + 16, latent_h, latent_w), dtype=torch.float32).cpu()
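        # Context buffer layout along dim=2: [1 | 2 | 16] latent frames, consumed each iteration as
        # clean_latents_post (1), clean_latents_2x (2), and clean_latents_4x (16) context for the
        # next (chronologically earlier) section.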
        accumulated_generated_latents = None
        history_pixels = None

        latent_paddings = list(reversed(range(total_latent_sections)))
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
            print(f" Using adjusted padding sequence for >4 sections: {latent_paddings}")
        else:
            print(f" Using standard padding sequence: {latent_paddings}")
        # --- [MODIFIED] Restore Initial Context Initialization ---
        if end_frame_latent is not None:
            print(" Initializing context buffer's first slot with end frame latent.")
            context_latents[:, :, 0:1, :, :] = end_frame_latent.cpu().float()  # Ensure float32 on CPU
        else:
            print(" No end frame latent available. Initial context remains zeros.")
        # --- End Modified Context Initialization ---

        # --- Main Sampling Loop (generates backwards: end -> start) ---
        start_time = time.time()
        num_loops = len(latent_paddings)
        for i_loop, latent_padding in enumerate(latent_paddings):
            section_start_time = time.time()
            current_section_index_from_end = latent_padding
            is_first_generation_step = (i_loop == 0)
            is_last_generation_step = (latent_padding == 0)
            print(f"\n--- Starting Generation Step {i_loop+1}/{num_loops} (Section Index from End: {current_section_index_from_end}, First Step: {is_first_generation_step}, Last Step: {is_last_generation_step}) ---")
            latent_padding_size = latent_padding * latent_window_size
            print(f' Padding size (latent frames): {latent_padding_size}, Window size (latent frames): {latent_window_size}')

            # --- Select Conditioning Inputs for this Section ---
            # 1. Conditioning latent (`clean_latents_pre`): choose the base, then optionally blend.
            base_conditioning_latent = start_latent  # Default to the start frame latent (float32, CPU)
            if current_section_index_from_end in section_latents:
                base_conditioning_latent = section_latents[current_section_index_from_end]  # Section keyframe latent (float32, CPU)
                print(f" Using SECTION {current_section_index_from_end} latent as base conditioning latent.")
            else:
                print(" Using START frame latent as base conditioning latent.")
            # Apply the 'bookend' override to the base latent for the first step only.
            if end_frame_influence == "bookend" and is_first_generation_step and end_frame_latent is not None:
                if current_section_index_from_end not in section_latents:
                    base_conditioning_latent = end_frame_latent  # float32, CPU
                    print(" Applying 'bookend': Overriding base conditioning latent with END frame latent for first step.")
            # Blend the base conditioning latent with the end frame latent based on mode/weight.
            current_conditioning_latent = base_conditioning_latent
            current_end_frame_latent_weight = 0.0
            if end_frame_latent is not None:  # Only blend if an end frame exists
                if end_frame_influence == 'progressive':
                    progress = i_loop / max(1, num_loops - 1)
                    current_end_frame_latent_weight = args.end_frame_weight * (1.0 - progress)
                elif end_frame_influence == 'half':
                    if i_loop < num_loops / 2:
                        current_end_frame_latent_weight = args.end_frame_weight
                # For 'last' and 'bookend', the weight remains 0 and no blending is needed.
                current_end_frame_latent_weight = max(0.0, min(1.0, current_end_frame_latent_weight))
                if current_end_frame_latent_weight > 1e-4:  # Mix only if the weight is significant
                    print(f" Blending Conditioning Latent: Base<-{1.0-current_end_frame_latent_weight:.3f} | End->{current_end_frame_latent_weight:.3f} (Mode: {end_frame_influence})")
                    # Ensure both inputs to mix_latents are float32 on CPU
                    current_conditioning_latent = mix_latents(base_conditioning_latent.cpu().float(),
                                                              end_frame_latent.cpu().float(),
                                                              current_end_frame_latent_weight)

            # 2. Image embedding: always the fixed start-frame embedding.
            current_image_embedding = start_image_embedding  # transformer.dtype, CPU
            print(" Using fixed START frame image embedding.")

            # 3. Text embedding: section-specific if available, otherwise global.
            if current_section_index_from_end in section_prompt_embeddings:
                current_llama_vec, current_clip_pooler = section_prompt_embeddings[current_section_index_from_end]
                print(f" Using SECTION {current_section_index_from_end} prompt embeddings.")
            else:
                current_llama_vec = global_llama_vec
                current_clip_pooler = global_clip_l_pooler
                print(" Using GLOBAL prompt embeddings.")
            current_llama_vec_n = global_llama_vec_n
            current_clip_pooler_n = global_clip_l_pooler_n
            current_llama_attention_mask = global_llama_attention_mask
            current_llama_attention_mask_n = global_llama_attention_mask_n

            # --- Prepare Sampler Inputs ---
            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
            clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = \
                indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
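            # Index layout fed to the sampler along the time axis:
            #   [clean_pre (1) | blank padding | sampled window | clean_post (1) | 2x (2) | 4x (16)]
            # clean_pre anchors the section to the (possibly blended) conditioning latent, while the
            # post/2x/4x slots come from the context buffer of already-generated (later) sections.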
            # Prepare conditioning latents (float32, CPU)
            clean_latents_pre = current_conditioning_latent  # Use the potentially blended latent
            clean_latents_post, clean_latents_2x, clean_latents_4x = \
                context_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
            print(f" Final Conditioning shapes (CPU): clean={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}")
            print(f" Clean Latents Pre stats - Min: {clean_latents_pre.min().item():.4f}, Max: {clean_latents_pre.max().item():.4f}, Mean: {clean_latents_pre.mean().item():.4f}")

            # Load Transformer (Low VRAM)
            if not high_vram:
                print(" Moving Transformer to GPU...")
                unload_complete_models()
                move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                fake_diffusers_current_device(text_encoder, device)

            # Configure TeaCache
            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
                print(" TeaCache enabled.")
            else:
                transformer.initialize_teacache(enable_teacache=False)
                print(" TeaCache disabled.")

            # --- Run Sampling ---
            print(f" Starting sampling ({steps} steps) for {num_frames} frames...")
            sampling_step_start_time = time.time()
            pbar = tqdm(total=steps, desc=f" Section {current_section_index_from_end} Sampling", leave=False)

            def callback(d):
                pbar.update(1)
                return

            current_sampler_device = transformer.device
            current_text_encoder_device = text_encoder.device if not high_vram else device
            # Move tensors to their devices just before sampling
            _prompt_embeds = current_llama_vec.to(current_text_encoder_device)
            _prompt_embeds_mask = current_llama_attention_mask.to(current_text_encoder_device)
            _prompt_poolers = current_clip_pooler.to(current_sampler_device)
            _negative_prompt_embeds = current_llama_vec_n.to(current_text_encoder_device)
            _negative_prompt_embeds_mask = current_llama_attention_mask_n.to(current_text_encoder_device)
            _negative_prompt_poolers = current_clip_pooler_n.to(current_sampler_device)
            _image_embeddings = current_image_embedding.to(current_sampler_device)  # Fixed start embedding
            _latent_indices = latent_indices.to(current_sampler_device)
            # Pass the (potentially blended) conditioning latents to the sampler
            _clean_latents = clean_latents.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_indices = clean_latent_indices.to(current_sampler_device)
            _clean_latents_2x = clean_latents_2x.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_2x_indices = clean_latent_2x_indices.to(current_sampler_device)
            _clean_latents_4x = clean_latents_4x.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_4x_indices = clean_latent_4x_indices.to(current_sampler_device)
            generated_latents_gpu = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=_prompt_embeds,
                prompt_embeds_mask=_prompt_embeds_mask,
                prompt_poolers=_prompt_poolers,
                negative_prompt_embeds=_negative_prompt_embeds,
                negative_prompt_embeds_mask=_negative_prompt_embeds_mask,
                negative_prompt_poolers=_negative_prompt_poolers,
                device=current_sampler_device,
                dtype=transformer.dtype,
                image_embeddings=_image_embeddings,        # Fixed start embedding
                latent_indices=_latent_indices,
                clean_latents=_clean_latents,              # Potentially blended conditioning latents
                clean_latent_indices=_clean_latent_indices,
                clean_latents_2x=_clean_latents_2x,
                clean_latent_2x_indices=_clean_latent_2x_indices,
                clean_latents_4x=_clean_latents_4x,
                clean_latent_4x_indices=_clean_latent_4x_indices,
                callback=callback,
            )
            pbar.close()
            sampling_step_end_time = time.time()
            print(f" Sampling finished in {sampling_step_end_time - sampling_step_start_time:.2f} seconds.")
            print(f" Raw generated latent shape for this step: {generated_latents_gpu.shape}")
            print(f" Generated latents stats (GPU) - Min: {generated_latents_gpu.min().item():.4f}, Max: {generated_latents_gpu.max().item():.4f}, Mean: {generated_latents_gpu.mean().item():.4f}")

            # Move generated latents to CPU as float32
            generated_latents_cpu = generated_latents_gpu.cpu().float()
            del generated_latents_gpu, _prompt_embeds, _prompt_embeds_mask, _prompt_poolers, _negative_prompt_embeds, _negative_prompt_embeds_mask, _negative_prompt_poolers
            del _image_embeddings, _latent_indices, _clean_latents, _clean_latent_indices, _clean_latents_2x, _clean_latent_2x_indices, _clean_latents_4x, _clean_latent_4x_indices
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            # Offload Transformer and Text Encoder 1 (Low VRAM)
            if not high_vram:
                print(" Low VRAM mode: Offloading Transformer and Text Encoder from GPU...")
                offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                offload_model_from_device_for_memory_preservation(text_encoder, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                print(" Transformer and Text Encoder offloaded.")

            # --- History/Context Update ---
            if is_last_generation_step:
                print(" Last generation step: Prepending start frame latent to generated latents.")
                generated_latents_cpu = torch.cat([start_latent.cpu().float(), generated_latents_cpu], dim=2)
                print(f" Shape after prepending start latent: {generated_latents_cpu.shape}")
            context_latents = torch.cat([generated_latents_cpu, context_latents], dim=2)
            print(f" Context buffer updated. New shape: {context_latents.shape}")

            # Accumulate the generated latents for the final video output
            if accumulated_generated_latents is None:
                accumulated_generated_latents = generated_latents_cpu
            else:
                accumulated_generated_latents = torch.cat([generated_latents_cpu, accumulated_generated_latents], dim=2)
            current_total_latent_frames = accumulated_generated_latents.shape[2]
            print(f" Accumulated generated latents updated. Total latent frames: {current_total_latent_frames}")
            print(f" Accumulated latents stats - Min: {accumulated_generated_latents.min().item():.4f}, Max: {accumulated_generated_latents.max().item():.4f}, Mean: {accumulated_generated_latents.mean().item():.4f}")

            # --- VAE Decoding & Merging ---
            print(" Decoding generated latents and merging video...")
            decode_start_time = time.time()
            if not high_vram:
                print(" Moving VAE to GPU...")
                offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                unload_complete_models(text_encoder, text_encoder_2, image_encoder)
                load_model_as_complete(vae, target_device=device)
                print(" VAE loaded.")
            print(f" Decoding current section's latents (shape: {generated_latents_cpu.shape}) for append.")
            latents_to_decode_for_append = generated_latents_cpu.to(device=device, dtype=vae.dtype)
            current_pixels = vae_decode(latents_to_decode_for_append, vae).cpu().float()  # Decode, then move to CPU float32
            print(f" Decoded pixels for append shape: {current_pixels.shape}")
            del latents_to_decode_for_append
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            if history_pixels is None:
                history_pixels = current_pixels
                print(f" Initialized history_pixels shape: {history_pixels.shape}")
            else:
                append_overlap = 3
                print(f" Appending section with pixel overlap: {append_overlap}")
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlap=append_overlap)
                print(f" Appended. New total pixel shape: {history_pixels.shape}")
            if not high_vram:
                print(" Low VRAM mode: Unloading VAE from GPU...")
                unload_complete_models(vae)
                print(" VAE unloaded.")
            decode_end_time = time.time()
            print(f" Decoding and merging finished in {decode_end_time - decode_start_time:.2f} seconds.")

            # --- Save Intermediate/Section Output ---
            current_num_pixel_frames = history_pixels.shape[2]
            if save_section_frames:
                try:
                    first_frame_index = 0  # Index 0 of the newly decoded chunk is the first frame generated in this step
                    frame_to_save = current_pixels[0, :, first_frame_index, :, :]
                    frame_to_save = einops.rearrange(frame_to_save, 'c h w -> h w c')
                    frame_to_save_np = frame_to_save.cpu().numpy()
                    frame_to_save_np = np.clip((frame_to_save_np * 127.5 + 127.5), 0, 255).astype(np.uint8)
                    section_frame_filename = output_dir / f'{job_id}_section_start_frame_idx{current_section_index_from_end}.png'
                    Image.fromarray(frame_to_save_np).save(section_frame_filename)
                    print(f" Saved first generated pixel frame of section {current_section_index_from_end} (from decoded chunk) to: {section_frame_filename}")
                except Exception as e:
                    print(f" [WARN] Error saving section {current_section_index_from_end} start frame image: {e}")

            if save_intermediate or is_last_generation_step:
                output_filename = output_dir / f'{job_id}_step{i_loop+1}_idx{current_section_index_from_end}_frames{current_num_pixel_frames}_{width}x{height}.mp4'
                print(f" Saving {'intermediate' if not is_last_generation_step else 'final'} video ({current_num_pixel_frames} frames) to: {output_filename}")
                try:
                    save_bcthw_as_mp4(history_pixels.float(), str(output_filename), fps=int(fps))
                    print(" Saved video using save_bcthw_as_mp4")
                    if not is_last_generation_step:
                        print(f"INTERMEDIATE_VIDEO_PATH:{output_filename}")
                    final_video_path = str(output_filename)
                except Exception as e:
                    print(f" Error saving video using save_bcthw_as_mp4: {e}")
                    traceback.print_exc()
                    # Fallback: save the first frame as a still image
                    try:
                        first_frame_img = history_pixels.float()[0, :, 0].permute(1, 2, 0).cpu().numpy()
                        first_frame_img = (first_frame_img * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
                        frame_path = str(output_filename).replace('.mp4', '_first_frame_ERROR.png')
                        Image.fromarray(first_frame_img).save(frame_path)
                        print(f" Saved first frame as image to {frame_path} due to video saving error.")
                    except Exception as frame_err:
                        print(f" Could not save first frame either: {frame_err}")

            section_end_time = time.time()
            print(f"--- Generation Step {i_loop+1} finished in {section_end_time - section_start_time:.2f} seconds ---")
            if is_last_generation_step:
                print("\nFinal generation step completed.")
                break

        # --- Final Video Saved During Last Step ---
        if final_video_path and os.path.exists(final_video_path):
            print(f"\nSuccessfully generated: {final_video_path}")
            print(f"ACTUAL_FINAL_PATH:{final_video_path}")
            return final_video_path
        else:
            print("\nError: Final video path not found or not saved correctly.")
            return None
    except Exception as e:
        print("\n--- ERROR DURING GENERATION ---")
        traceback.print_exc()
        print("-----------------------------")
        if 'history_pixels' in locals() and history_pixels is not None and history_pixels.shape[2] > 0:
            partial_output_name = output_dir / f"{job_id}_partial_ERROR_{history_pixels.shape[2]}_frames_{width}x{height}.mp4"
            print(f"Attempting to save partial video to: {partial_output_name}")
            try:
                save_bcthw_as_mp4(history_pixels.float(), str(partial_output_name), fps=fps)
                print(f"ACTUAL_FINAL_PATH:{partial_output_name}")
                return str(partial_output_name)
            except Exception as save_err:
                print(f"Error saving partial video during error handling: {save_err}")
                traceback.print_exc()
        print("Status: Error occurred, no video saved.")
        return None
    finally:
        print("Performing final model cleanup...")
        try:
            unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
        except Exception as e:
            print(f"Error during final model unload: {e}")
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")

def main():
    args = parse_args()
    models = load_models(args)
    final_path = generate_video(args, models)
    if final_path:
        print(f"\nVideo generation finished. Final path: {final_path}")
        sys.exit(0)
    else:
        print("\nVideo generation failed.")
        sys.exit(1)


if __name__ == "__main__":
    main()