# Framepack-H111 / f_video_end_cli_local.py
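"""CLI tool that extends an input video using FramePack (HunyuanVideo) generation, optionally
steering the extension toward a user-supplied end frame. New latent sections are sampled
backward (furthest from the input video first), decoded, and stitched onto the input video,
or saved on their own when --extension_only is passed."""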
import os
import torch
import traceback
import einops
import numpy as np
import argparse
import math
import decord
from tqdm import tqdm
import pathlib
from datetime import datetime
import imageio_ffmpeg
import tempfile
import shutil
import subprocess
import sys
from PIL import Image
try:
from frame_pack.hunyuan_video_packed import load_packed_model
from frame_pack.framepack_utils import (
load_vae,
load_text_encoder1,
load_text_encoder2,
load_image_encoders
)
from frame_pack.hunyuan import encode_prompt_conds, vae_decode, vae_encode # vae_decode_fake might be needed for previews if added
from frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
from frame_pack.k_diffusion_hunyuan import sample_hunyuan
from frame_pack.clip_vision import hf_clip_vision_encode
from frame_pack.bucket_tools import find_nearest_bucket
from diffusers_helper.utils import save_bcthw_as_mp4 # from a common helper library
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, \
move_model_to_device_with_memory_preservation, \
offload_model_from_device_for_memory_preservation, \
fake_diffusers_current_device, DynamicSwapInstaller, \
unload_complete_models, load_model_as_complete
# For LoRA
from networks import lora_framepack
try:
from lycoris.kohya import create_network_from_weights
except ImportError:
pass # Lycoris optional
from base_wan_generate_video import merge_lora_weights # Assuming this is accessible
except ImportError as e:
print(f"Error importing FramePack related modules: {e}. Ensure they are in PYTHONPATH.")
sys.exit(1)
# --- Global Model Variables ---
text_encoder = None
text_encoder_2 = None
tokenizer = None
tokenizer_2 = None
vae = None
feature_extractor = None
image_encoder = None
transformer = None
high_vram = False
free_mem_gb = 0.0
outputs_folder = './outputs/' # Default, can be overridden by --output_dir
@torch.no_grad()
def video_encode(video_path, resolution, no_resize, vae_model, vae_batch_size=16, device="cuda", width=None, height=None):
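    """Read a video, resize/center-crop its frames to the target bucket resolution, and VAE-encode them.

    Returns: first-frame latent, first frame as a numpy image (for CLIP vision), the full latent
    history, FPS, the actual target height and width, the preprocessed pixel tensor (CPU),
    the last-frame latent, and the last frame as a numpy image.
    """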
video_path = str(pathlib.Path(video_path).resolve())
print(f"Processing video for encoding: {video_path}")
if device == "cuda" and not torch.cuda.is_available():
print("CUDA is not available, falling back to CPU for video_encode")
device = "cpu"
try:
print("Initializing VideoReader...")
vr = decord.VideoReader(video_path)
fps = vr.get_avg_fps()
if fps == 0:
print("Warning: VideoReader reported FPS as 0. Attempting to get it via OpenCV.")
import cv2
cap = cv2.VideoCapture(video_path)
fps_cv = cap.get(cv2.CAP_PROP_FPS)
cap.release()
if fps_cv > 0:
fps = fps_cv
print(f"Using FPS from OpenCV: {fps}")
else:
# Fallback FPS if all else fails
fps = 25
print(f"Failed to determine FPS for the input video. Defaulting to {fps} FPS.")
num_real_frames = len(vr)
print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")
        latent_size_factor = 4 # HunyuanVideo's VAE compresses time by 4x (and space by 8x), so keep a multiple of 4 frames
num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
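        # e.g. 103 real frames -> num_frames = 100, the largest multiple of 4 not exceeding 103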
if num_frames != num_real_frames:
print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility (multiple of {latent_size_factor})")
if num_frames == 0:
raise ValueError(f"Video too short ({num_real_frames} frames) or becomes 0 after truncation. Needs at least {latent_size_factor} frames.")
num_real_frames = num_frames
print("Reading video frames...")
frames_np_all = vr.get_batch(range(num_real_frames)).asnumpy()
print(f"Frames read: {frames_np_all.shape}")
native_height, native_width = frames_np_all.shape[1], frames_np_all.shape[2]
print(f"Native video resolution: {native_width}x{native_height}")
target_h_arg = native_height if height is None else height
target_w_arg = native_width if width is None else width
if not no_resize:
actual_target_height, actual_target_width = find_nearest_bucket(target_h_arg, target_w_arg, resolution=resolution)
print(f"Adjusted resolution for VAE encoding: {actual_target_width}x{actual_target_height}")
else:
actual_target_width = (native_width // 8) * 8
actual_target_height = (native_height // 8) * 8
if actual_target_width != native_width or actual_target_height != native_height:
print(f"Using native resolution, adjusted to be divisible by 8: {actual_target_width}x{actual_target_height}")
else:
print(f"Using native resolution without resizing: {actual_target_width}x{actual_target_height}")
processed_frames_list = []
for frame_idx in range(frames_np_all.shape[0]):
frame = frames_np_all[frame_idx]
frame_resized_np = resize_and_center_crop(frame, target_width=actual_target_width, target_height=actual_target_height)
processed_frames_list.append(frame_resized_np)
processed_frames_np_stack = np.stack(processed_frames_list)
print(f"Frames preprocessed: {processed_frames_np_stack.shape}")
input_image_np_for_clip_first = processed_frames_np_stack[0]
input_image_np_for_clip_last = processed_frames_np_stack[-1]
print("Converting frames to tensor...")
frames_pt = torch.from_numpy(processed_frames_np_stack).float() / 127.5 - 1.0
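        # Maps uint8 pixel values from [0, 255] into [-1, 1] before VAE encoding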
frames_pt = frames_pt.permute(0, 3, 1, 2) # B, H, W, C -> B, C, H, W
frames_pt = frames_pt.unsqueeze(0).permute(0, 2, 1, 3, 4) # B, C, H, W -> 1, C, B, H, W (as VAE expects 1,C,F,H,W)
print(f"Tensor shape for VAE: {frames_pt.shape}")
input_video_pixels_cpu = frames_pt.clone().cpu()
print(f"Moving VAE and tensor to device: {device}")
vae_model.to(device)
frames_pt = frames_pt.to(device)
print(f"Encoding input video frames with VAE (batch size: {vae_batch_size})")
all_latents_list = []
vae_model.eval()
with torch.no_grad():
for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="VAE Encoding Video Frames", mininterval=0.1):
batch_frames_pt = frames_pt[:, :, i:i + vae_batch_size]
try:
batch_latents = vae_encode(batch_frames_pt, vae_model)
all_latents_list.append(batch_latents.cpu())
except RuntimeError as e:
print(f"Error during VAE encoding: {str(e)}")
if "out of memory" in str(e).lower() and device == "cuda":
print("CUDA out of memory during VAE encoding. Try reducing --vae_batch_size or use CPU for VAE.")
raise
history_latents_cpu = torch.cat(all_latents_list, dim=2)
print(f"History latents shape (original video): {history_latents_cpu.shape}")
start_latent_cpu = history_latents_cpu[:, :, :1].clone()
end_of_input_video_latent_cpu = history_latents_cpu[:, :, -1:].clone()
print(f"Start latent shape (for conditioning): {start_latent_cpu.shape}")
print(f"End of input video latent shape: {end_of_input_video_latent_cpu.shape}")
if device == "cuda":
vae_model.to(cpu) # Move VAE back to CPU
torch.cuda.empty_cache()
print("VAE moved back to CPU, CUDA cache cleared")
return (start_latent_cpu, input_image_np_for_clip_first,
history_latents_cpu, fps,
actual_target_height, actual_target_width,
input_video_pixels_cpu,
end_of_input_video_latent_cpu, input_image_np_for_clip_last)
except Exception as e:
print(f"Error in video_encode: {str(e)}")
traceback.print_exc()
raise
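# Example usage (a sketch: assumes a VAE already loaded via load_vae and a clip at ./input.mp4):
#   (start_lat, first_np, hist_lat, fps, h, w,
#    pixels, end_lat, last_np) = video_encode("./input.mp4", resolution=640, no_resize=False,
#                                             vae_model=vae, vae_batch_size=16, device="cuda")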
@torch.no_grad()
def image_encode(image_np, target_width, target_height, vae_model, image_encoder_model, feature_extractor_model, device="cuda"):
"""
Encode a single image into a latent and compute its CLIP vision embedding.
"""
global high_vram # Use global high_vram status
print("Processing single image for encoding (e.g., end_frame)...")
try:
print(f"Using target resolution for image encoding: {target_width}x{target_height}")
processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)
image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1.0
image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2) # N C F H W (N=1, F=1)
target_vae_device = device
if not high_vram: load_model_as_complete(vae_model, target_device=target_vae_device)
else: vae_model.to(target_vae_device)
image_pt_device = image_pt.to(target_vae_device)
latent = vae_encode(image_pt_device, vae_model).cpu() # Encode and move to CPU
print(f"Single image VAE output shape (latent): {latent.shape}")
if not high_vram: unload_complete_models(vae_model) # Offload VAE if low VRAM
target_img_enc_device = device
if not high_vram: load_model_as_complete(image_encoder_model, target_device=target_img_enc_device)
else: image_encoder_model.to(target_img_enc_device)
clip_embedding_output = hf_clip_vision_encode(processed_image_np, feature_extractor_model, image_encoder_model)
clip_embedding = clip_embedding_output.last_hidden_state.cpu() # Encode and move to CPU
print(f"Single image CLIP embedding shape: {clip_embedding.shape}")
if not high_vram: unload_complete_models(image_encoder_model) # Offload image encoder if low VRAM
if device == "cuda":
torch.cuda.empty_cache()
# print("CUDA cache cleared after single image_encode")
return latent, clip_embedding, processed_image_np
except Exception as e:
print(f"Error in image_encode: {str(e)}")
traceback.print_exc()
raise
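# Example usage (a sketch: assumes vae/image_encoder/feature_extractor are already loaded and end.png exists):
#   end_np = np.array(Image.open("end.png").convert("RGB"))
#   end_latent, end_clip_emb, end_np_resized = image_encode(
#       end_np, target_width=640, target_height=384, vae_model=vae,
#       image_encoder_model=image_encoder, feature_extractor_model=feature_extractor)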
def set_mp4_comments_imageio_ffmpeg(input_file, comments):
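    """Write `comments` into the MP4 comment metadata field via an ffmpeg stream copy (no re-encode).

    Returns True on success, False otherwise.
    """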
try:
ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
if not os.path.exists(input_file):
print(f"Error: Input file {input_file} does not exist")
return False
temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
command = [
ffmpeg_path, '-i', input_file, '-metadata', f'comment={comments}',
'-c:v', 'copy', '-c:a', 'copy', '-y', temp_file
]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
if result.returncode == 0:
shutil.move(temp_file, input_file)
print(f"Successfully added comments to {input_file}")
return True
else:
if os.path.exists(temp_file): os.remove(temp_file)
print(f"Error: FFmpeg failed with message:\n{result.stderr}")
return False
except Exception as e:
if 'temp_file' in locals() and os.path.exists(temp_file): os.remove(temp_file)
print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e))
return False
@torch.no_grad()
def do_generation_work(
input_video_path, prompt, n_prompt, seed,
end_frame_path, end_frame_weight, # New arguments
resolution_max_dim,
additional_second_length,
latent_window_size, steps, cfg, gs, rs,
gpu_memory_preservation, use_teacache, no_resize, mp4_crf,
num_clean_frames, vae_batch_size,
extension_only
):
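    """Run the full extension pipeline.

    Encodes the prompts and the input video, then samples new latent sections backward
    (furthest from the input video first), decoding and stitching pixels after each section
    until the chunk adjacent to the input video has been generated.
    """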
global high_vram, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, feature_extractor, image_encoder, transformer, args
print('--- Starting Video Generation (with End Frame support) ---')
try:
# --- Text Encoding ---
print('Text encoding...')
target_text_enc_device = str(gpu if torch.cuda.is_available() else cpu)
if not high_vram:
if text_encoder: fake_diffusers_current_device(text_encoder, target_text_enc_device) # DynamicSwapInstaller for text_encoder
if text_encoder_2: load_model_as_complete(text_encoder_2, target_device=target_text_enc_device)
else:
if text_encoder: text_encoder.to(target_text_enc_device)
if text_encoder_2: text_encoder_2.to(target_text_enc_device)
llama_vec_gpu, clip_l_pooler_gpu = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
if cfg == 1.0: # Note: Original FramePack usually uses gs, cfg=1 means gs is active
llama_vec_n_gpu, clip_l_pooler_n_gpu = torch.zeros_like(llama_vec_gpu), torch.zeros_like(clip_l_pooler_gpu)
else: # If cfg > 1.0, it implies standard CFG, so n_prompt is used. gs should be 1.0 in this case.
llama_vec_n_gpu, clip_l_pooler_n_gpu = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
# Store on CPU
llama_vec_padded_cpu, llama_attention_mask_cpu = crop_or_pad_yield_mask(llama_vec_gpu.cpu(), length=512)
llama_vec_n_padded_cpu, llama_attention_mask_n_cpu = crop_or_pad_yield_mask(llama_vec_n_gpu.cpu(), length=512)
clip_l_pooler_cpu = clip_l_pooler_gpu.cpu()
clip_l_pooler_n_cpu = clip_l_pooler_n_gpu.cpu()
if not high_vram: unload_complete_models(text_encoder_2) # text_encoder is managed by DynamicSwap
# --- Video and End Frame Encoding ---
print('Encoding input video...')
video_encode_device = str(gpu if torch.cuda.is_available() else cpu)
(start_latent_input_cpu, input_image_np_first,
video_latents_history_cpu, fps, height, width,
input_video_pixels_cpu,
end_of_input_video_latent_cpu, input_image_np_last) = video_encode(
input_video_path, resolution_max_dim, no_resize, vae,
vae_batch_size=vae_batch_size, device=video_encode_device,
width=None, height=None # video_encode will use resolution_max_dim
)
if fps <= 0: raise ValueError("FPS from input video is 0 or invalid.")
end_latent_from_file_cpu, end_clip_embedding_from_file_cpu = None, None
if end_frame_path:
print(f"Encoding provided end frame from: {end_frame_path}")
end_frame_pil = Image.open(end_frame_path).convert("RGB")
end_frame_np = np.array(end_frame_pil)
end_latent_from_file_cpu, end_clip_embedding_from_file_cpu, _ = image_encode(
end_frame_np, target_width=width, target_height=height,
vae_model=vae, image_encoder_model=image_encoder,
feature_extractor_model=feature_extractor, device=video_encode_device
)
# --- CLIP Vision Encoding for first and last frames of input video ---
print('CLIP Vision encoding for input video frames...')
target_img_enc_device = str(gpu if torch.cuda.is_available() else cpu)
if not high_vram: load_model_as_complete(image_encoder, target_device=target_img_enc_device)
else: image_encoder.to(target_img_enc_device)
        # For original FramePack, the image_embeddings passed to sample_hunyuan usually come from the *start* image.
# Script 2 uses end_of_input_video_embedding or a blend with the explicit end_frame.
# We will follow script 2 for conditioning.
# start_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_first, feature_extractor, image_encoder).last_hidden_state.cpu()
end_of_input_video_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_last, feature_extractor, image_encoder).last_hidden_state.cpu()
if not high_vram: unload_complete_models(image_encoder)
# Determine final image embedding for sampling loop
if end_clip_embedding_from_file_cpu is not None:
print(f"Blending end-of-input-video embedding with provided end_frame embedding (weight: {end_frame_weight})")
final_clip_embedding_for_sampling_cpu = \
(1.0 - end_frame_weight) * end_of_input_video_clip_embedding_cpu + \
end_frame_weight * end_clip_embedding_from_file_cpu
else:
print("Using end-of-input-video's last frame embedding for image conditioning.")
final_clip_embedding_for_sampling_cpu = end_of_input_video_clip_embedding_cpu.clone()
# --- Prepare for Sampling Loop ---
target_transformer_device = str(gpu if torch.cuda.is_available() else cpu)
if not high_vram:
if transformer: move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
else:
if transformer: transformer.to(target_transformer_device)
cond_device = transformer.device
cond_dtype = transformer.dtype
# Move conditioning tensors to transformer's device and dtype
llama_vec = llama_vec_padded_cpu.to(device=cond_device, dtype=cond_dtype)
llama_attention_mask = llama_attention_mask_cpu.to(device=cond_device) # Mask is usually bool/int
clip_l_pooler = clip_l_pooler_cpu.to(device=cond_device, dtype=cond_dtype)
llama_vec_n = llama_vec_n_padded_cpu.to(device=cond_device, dtype=cond_dtype)
llama_attention_mask_n = llama_attention_mask_n_cpu.to(device=cond_device)
clip_l_pooler_n = clip_l_pooler_n_cpu.to(device=cond_device, dtype=cond_dtype)
# This is the image embedding that will be used in the sampling loop
image_embeddings_for_sampling_loop = final_clip_embedding_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype)
        # Note: clean_latents_pre is taken from the tail of the input video's latents
        # (script 2 uses `video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):]`),
        # not from its first frame. For backward generation, the "start_latent" role in
        # sample_hunyuan's `clean_latents` argument (script 2: `torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)`)
        # is played by the *last* frame of the input video when generating the chunk closest to it,
        # so end_of_input_video_latent_cpu is used for that purpose below.
num_output_pixel_frames_per_section = latent_window_size * 4 # Not -3 here, as this is for total section calc
if num_output_pixel_frames_per_section == 0:
raise ValueError("latent_window_size * 4 is zero, cannot calculate total_extension_latent_sections.")
total_extension_latent_sections = int(max(round((additional_second_length * fps) / num_output_pixel_frames_per_section), 1))
print(f"Input video FPS: {fps}, Target additional length: {additional_second_length}s")
print(f"Generating {total_extension_latent_sections} new sections for extension (approx {total_extension_latent_sections * num_output_pixel_frames_per_section / fps:.2f}s).")
job_id_base = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + \
f"_framepack-vidEndFrm_{width}x{height}_{additional_second_length:.1f}s_seed{seed}_s{steps}_gs{gs}_cfg{cfg}"
job_id = job_id_base
        if extension_only:
job_id += "_extonly"
print("Extension-only mode enabled. Filenames will reflect this.")
rnd = torch.Generator("cpu").manual_seed(seed)
        # Latent geometry, needed later for zero-filled fallbacks regardless of which branch runs below
        channels_dim = video_latents_history_cpu.shape[1] # Channel count from the input video latents
        latent_h, latent_w = height // 8, width // 8
        # Initialize history for generated latents (starts empty, or seeded with end_latent_from_file)
        if end_latent_from_file_cpu is not None:
            # This assumes end_latent_from_file_cpu is [1,C,1,H,W]; more frames might be needed if it is used as a seed.
            # Script 2's logic for clean_latents_post when is_end_of_video uses just 1 frame.
            history_latents_generated_cpu = end_latent_from_file_cpu.clone()
        else:
            history_latents_generated_cpu = torch.empty((1, channels_dim, 0, latent_h, latent_w), dtype=torch.float32, device='cpu')
# Initialize history for decoded pixels (starts empty)
history_pixels_decoded_cpu = None
total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]
previous_video_path_for_cleanup = None
# Backward generation loop (from demo_gradio_video+endframe.py)
latent_paddings = list(reversed(range(total_extension_latent_sections)))
if total_extension_latent_sections > 4: # Heuristic from script 2
latent_paddings = [3] + [2] * (total_extension_latent_sections - 3) + [1, 0]
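        # e.g. 6 sections -> paddings [3, 2, 2, 2, 1, 0]; 3 sections -> [2, 1, 0]. Padding 0 is the chunk adjacent to the input video.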
for loop_idx, latent_padding_val in enumerate(latent_paddings):
current_section_num_from_end = loop_idx + 1
is_start_of_extension = (latent_padding_val == 0) # This is the chunk closest to input video
is_end_of_extension = (latent_padding_val == latent_paddings[0]) # This is the chunk furthest from input video
print(f"--- Generating Extension: Seed {seed}: Section {current_section_num_from_end}/{total_extension_latent_sections} (backward), padding={latent_padding_val} ---")
if transformer: transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps if use_teacache else 0)
progress_bar_sampler = tqdm(total=steps, desc=f"Sampling Extension Section {current_section_num_from_end}/{total_extension_latent_sections}", file=sys.stdout, dynamic_ncols=True)
def sampler_callback_cli(d): progress_bar_sampler.update(1)
# Context frame calculation (from demo_gradio_video+endframe.py worker)
# `available_frames` for context refers to previously *generated* frames or input video frames
# For `clean_latents_pre`, it's always from `video_latents_history_cpu`
# For `clean_latents_post`, `_2x`, `_4x`, it's from `history_latents_generated_cpu`
effective_clean_frames_count = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
# For clean_latents_pre (from input video)
# If is_start_of_extension, we might want stronger anchoring to input video. Script 2 uses full `effective_clean_frames_count`.
clean_latent_pre_frames_num = effective_clean_frames_count
if is_start_of_extension: # Closest to input video
clean_latent_pre_frames_num = 1 # Script 2 uses 1 to avoid jumpcuts from input video when generating chunk closest to it.
# For clean_latents_post, _2x, _4x (from previously generated extension chunks)
available_generated_latents = history_latents_generated_cpu.shape[2]
# `post_frames_num` is for clean_latents_post
            # post_frames_num is 1 when this is the furthest chunk and an explicit end frame was provided (script 2 detail)
            post_frames_num = 1 if is_end_of_extension and end_latent_from_file_cpu is not None else effective_clean_frames_count
num_2x_frames_count = min(2, max(0, available_generated_latents - post_frames_num -1))
num_4x_frames_count = min(16, max(0, available_generated_latents - post_frames_num - num_2x_frames_count))
# Latent indexing for sample_hunyuan (from script 2)
latent_padding_size_for_indices = latent_padding_val * latent_window_size
pixel_frames_to_generate_this_step = latent_window_size * 4 - 3
indices_tensor_gpu = torch.arange(0,
clean_latent_pre_frames_num +
latent_padding_size_for_indices +
latent_window_size + # Note: script 2 uses latent_window_size here for `latent_indices` count
post_frames_num +
num_2x_frames_count +
num_4x_frames_count
).unsqueeze(0).to(cond_device)
(clean_latent_indices_pre_gpu,
blank_indices_gpu, # For padding
latent_indices_for_denoising_gpu, # For new generation
clean_latent_indices_post_gpu,
clean_latent_2x_indices_gpu,
clean_latent_4x_indices_gpu
) = indices_tensor_gpu.split(
[clean_latent_pre_frames_num, latent_padding_size_for_indices, latent_window_size,
post_frames_num, num_2x_frames_count, num_4x_frames_count], dim=1
)
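            # Index layout along dim=1: [pre (from input video) | blank padding | latents to denoise | post | 2x context | 4x context]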
clean_latent_indices_combined_gpu = torch.cat([clean_latent_indices_pre_gpu, clean_latent_indices_post_gpu], dim=1)
# Prepare conditioning latents
# clean_latents_pre_cpu: from end of input video
actual_pre_frames_to_take = min(clean_latent_pre_frames_num, video_latents_history_cpu.shape[2])
clean_latents_pre_cpu = video_latents_history_cpu[:, :, -actual_pre_frames_to_take:].clone()
if clean_latents_pre_cpu.shape[2] < clean_latent_pre_frames_num and clean_latents_pre_cpu.shape[2] > 0: # Pad if necessary
repeats = math.ceil(clean_latent_pre_frames_num / clean_latents_pre_cpu.shape[2])
clean_latents_pre_cpu = clean_latents_pre_cpu.repeat(1,1,repeats,1,1)[:,:,:clean_latent_pre_frames_num]
elif clean_latents_pre_cpu.shape[2] == 0 and clean_latent_pre_frames_num > 0: # Should not happen if video_latents_history_cpu is valid
clean_latents_pre_cpu = torch.zeros((1,channels_dim,clean_latent_pre_frames_num,latent_h,latent_w),dtype=torch.float32)
# clean_latents_post_cpu, _2x_cpu, _4x_cpu: from start of `history_latents_generated_cpu`
current_offset_in_generated = 0
# Post frames
actual_post_frames_to_take = min(post_frames_num, history_latents_generated_cpu.shape[2])
if is_end_of_extension and end_latent_from_file_cpu is not None:
clean_latents_post_cpu = end_latent_from_file_cpu.clone() # Should be [1,C,1,H,W]
else:
clean_latents_post_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_post_frames_to_take].clone()
current_offset_in_generated += clean_latents_post_cpu.shape[2]
if clean_latents_post_cpu.shape[2] < post_frames_num and clean_latents_post_cpu.shape[2] > 0: # Pad
repeats = math.ceil(post_frames_num / clean_latents_post_cpu.shape[2])
clean_latents_post_cpu = clean_latents_post_cpu.repeat(1,1,repeats,1,1)[:,:,:post_frames_num]
elif clean_latents_post_cpu.shape[2] == 0 and post_frames_num > 0: # Fill with zeros if no history and no end_latent
clean_latents_post_cpu = torch.zeros((1,channels_dim,post_frames_num,latent_h,latent_w),dtype=torch.float32)
# 2x frames
actual_2x_frames_to_take = min(num_2x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
clean_latents_2x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_2x_frames_to_take].clone()
current_offset_in_generated += clean_latents_2x_cpu.shape[2]
if clean_latents_2x_cpu.shape[2] < num_2x_frames_count and clean_latents_2x_cpu.shape[2] > 0: # Pad
repeats = math.ceil(num_2x_frames_count / clean_latents_2x_cpu.shape[2])
clean_latents_2x_cpu = clean_latents_2x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_2x_frames_count]
elif clean_latents_2x_cpu.shape[2] == 0 and num_2x_frames_count > 0:
clean_latents_2x_cpu = torch.zeros((1,channels_dim,num_2x_frames_count,latent_h,latent_w),dtype=torch.float32)
# 4x frames
actual_4x_frames_to_take = min(num_4x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
clean_latents_4x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_4x_frames_to_take].clone()
if clean_latents_4x_cpu.shape[2] < num_4x_frames_count and clean_latents_4x_cpu.shape[2] > 0: # Pad
repeats = math.ceil(num_4x_frames_count / clean_latents_4x_cpu.shape[2])
clean_latents_4x_cpu = clean_latents_4x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_4x_frames_count]
elif clean_latents_4x_cpu.shape[2] == 0 and num_4x_frames_count > 0:
clean_latents_4x_cpu = torch.zeros((1,channels_dim,num_4x_frames_count,latent_h,latent_w),dtype=torch.float32)
# Combine pre and post for `clean_latents` argument
clean_latents_for_sampler_gpu = torch.cat([
clean_latents_pre_cpu.to(device=cond_device, dtype=torch.float32),
clean_latents_post_cpu.to(device=cond_device, dtype=torch.float32)
], dim=2)
# Ensure 2x and 4x latents are None if their frame counts are 0
# The k_diffusion_hunyuan.sample_hunyuan and the DiT should handle None for these if indices are also empty.
clean_latents_2x_gpu = None
if num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] > 0:
clean_latents_2x_gpu = clean_latents_2x_cpu.to(device=cond_device, dtype=torch.float32)
elif num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] == 0: # Should have been filled with zeros if count > 0
print(f"Warning: num_2x_frames_count is {num_2x_frames_count} but clean_latents_2x_cpu is empty. Defaulting to None.")
clean_latents_4x_gpu = None
if num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] > 0:
clean_latents_4x_gpu = clean_latents_4x_cpu.to(device=cond_device, dtype=torch.float32)
elif num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] == 0:
print(f"Warning: num_4x_frames_count is {num_4x_frames_count} but clean_latents_4x_cpu is empty. Defaulting to None.")
# Also, ensure indices are None or empty if counts are zero.
# The split logic already ensures this if the split size is 0.
# clean_latent_2x_indices_gpu will be shape (B, 0) if num_2x_frames_count is 0.
# The DiT model should correctly interpret an empty indices tensor or None for the corresponding latent.
generated_latents_gpu_step = sample_hunyuan(
transformer=transformer, sampler='unipc', width=width, height=height,
frames=pixel_frames_to_generate_this_step, # Num frames for current chunk
real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs,
num_inference_steps=steps, generator=rnd,
prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n,
device=cond_device, dtype=cond_dtype,
image_embeddings=image_embeddings_for_sampling_loop, # Use the blended/final one
latent_indices=latent_indices_for_denoising_gpu,
clean_latents=clean_latents_for_sampler_gpu,
clean_latent_indices=clean_latent_indices_combined_gpu,
clean_latents_2x=clean_latents_2x_gpu, # Can be None
clean_latent_2x_indices=clean_latent_2x_indices_gpu if num_2x_frames_count > 0 else None, # Pass None if count is 0
clean_latents_4x=clean_latents_4x_gpu, # Can be None
clean_latent_4x_indices=clean_latent_4x_indices_gpu if num_4x_frames_count > 0 else None, # Pass None if count is 0
callback=sampler_callback_cli,
)
if progress_bar_sampler: progress_bar_sampler.close()
# If this was the chunk closest to input video, prepend the last frame of input video for smoother transition
if is_start_of_extension:
generated_latents_gpu_step = torch.cat([
end_of_input_video_latent_cpu.to(generated_latents_gpu_step), # Use actual last frame latent
generated_latents_gpu_step
], dim=2)
# Prepend generated latents to history
history_latents_generated_cpu = torch.cat([generated_latents_gpu_step.cpu(), history_latents_generated_cpu], dim=2)
total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]
# --- Decode and Append Pixels ---
target_vae_device = str(gpu if torch.cuda.is_available() else cpu)
if not high_vram:
if transformer: offload_model_from_device_for_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
if vae: load_model_as_complete(vae, target_device=target_vae_device)
else:
if vae: vae.to(target_vae_device)
# Decode the newly generated part (or a relevant segment for stitching)
# Script 2 decodes `real_history_latents[:, :, :section_latent_frames]`
# section_latent_frames = (latent_window_size * 2 + 1) if is_start_of_video else (latent_window_size * 2)
num_latents_to_decode_for_stitch = (latent_window_size * 2 + 1) if is_start_of_extension else (latent_window_size * 2)
num_latents_to_decode_for_stitch = min(num_latents_to_decode_for_stitch, history_latents_generated_cpu.shape[2])
latents_for_current_decode_gpu = history_latents_generated_cpu[:, :, :num_latents_to_decode_for_stitch].to(target_vae_device)
pixels_for_current_part_decoded_cpu = vae_decode(latents_for_current_decode_gpu, vae).cpu()
# Soft append pixels (current_pixels, history_pixels, overlap)
overlap_for_soft_append = latent_window_size * 4 - 3
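            # e.g. latent_window_size=9 -> an overlap of 9*4-3 = 33 pixel frames is blended by soft_append_bcthw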
if history_pixels_decoded_cpu is None:
history_pixels_decoded_cpu = pixels_for_current_part_decoded_cpu
else:
overlap_actual = min(overlap_for_soft_append, history_pixels_decoded_cpu.shape[2], pixels_for_current_part_decoded_cpu.shape[2])
if overlap_actual <=0: # Should not happen with proper windowing
history_pixels_decoded_cpu = torch.cat([pixels_for_current_part_decoded_cpu, history_pixels_decoded_cpu], dim=2) # Simple prepend
else:
history_pixels_decoded_cpu = soft_append_bcthw(
pixels_for_current_part_decoded_cpu, # Current (prepended)
history_pixels_decoded_cpu, # History
overlap=overlap_actual
)
if not high_vram:
if vae: unload_complete_models(vae)
                if transformer and not is_start_of_extension: # Reload transformer for the next iteration
move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
# Save intermediate video
current_output_filename = os.path.join(outputs_folder, f'{job_id}_part{current_section_num_from_end}_totalframes{history_pixels_decoded_cpu.shape[2]}.mp4')
save_bcthw_as_mp4(history_pixels_decoded_cpu, current_output_filename, fps=fps, crf=mp4_crf)
print(f"MP4 Preview for section {current_section_num_from_end} saved: {current_output_filename}")
            set_mp4_comments_imageio_ffmpeg(current_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")
if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup):
try: os.remove(previous_video_path_for_cleanup)
except Exception as e_del: print(f"Error deleting {previous_video_path_for_cleanup}: {e_del}")
previous_video_path_for_cleanup = current_output_filename
if is_start_of_extension: # Last iteration of backward loop
break
# --- Final Video Assembly ---
        if extension_only:
print("Saving only the generated extension...")
# history_pixels_decoded_cpu already contains only the generated extension due to backward generation
# and how it's accumulated.
video_to_save_cpu = history_pixels_decoded_cpu
final_output_filename_suffix = "_extension_only_final.mp4"
final_log_message = "Final extension-only video saved:"
else:
print("Appending generated extension to the input video...")
# input_video_pixels_cpu is (1, C, F_in, H, W)
# history_pixels_decoded_cpu is (1, C, F_ext, H, W)
video_to_save_cpu = torch.cat([input_video_pixels_cpu, history_pixels_decoded_cpu], dim=2)
final_output_filename_suffix = "_final.mp4"
final_log_message = "Final extended video saved:"
final_output_filename = os.path.join(outputs_folder, f'{job_id}{final_output_filename_suffix}') # job_id already has _extonly if needed
save_bcthw_as_mp4(video_to_save_cpu, final_output_filename, fps=fps, crf=mp4_crf)
print(f"{final_log_message} {final_output_filename}")
        set_mp4_comments_imageio_ffmpeg(final_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")
if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup) and previous_video_path_for_cleanup != final_output_filename:
try: os.remove(previous_video_path_for_cleanup)
except Exception as e_del: print(f"Error deleting last part: {e_del}")
except Exception as e_outer:
traceback.print_exc()
print(f"Error during generation: {e_outer}")
finally:
if not high_vram:
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
print("--- Generation work cycle finished. ---")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="FramePack Video Generation CLI (with End Frame)")
# Inputs
parser.add_argument('--input_video', type=str, required=True, help='Path to the input video file.')
parser.add_argument('--prompt', type=str, required=True, help='Prompt for video generation.')
parser.add_argument('--n_prompt', type=str, default="", help='Negative prompt.')
parser.add_argument('--end_frame', type=str, default=None, help='Optional path to an image to guide the end of the video.')
parser.add_argument('--end_frame_weight', type=float, default=1.0, help='Weight for the end_frame image conditioning (0.0 to 1.0). Default 1.0.')
# Generation parameters
parser.add_argument('--seed', type=int, default=31337, help='Seed for generation.')
parser.add_argument('--resolution_max_dim', type=int, default=640, help='Target resolution (max width or height for bucket search).')
parser.add_argument('--total_second_length', type=float, default=5.0, help='Additional video length to generate (seconds).')
    parser.add_argument('--latent_window_size', type=int, default=9, help='Latent window size (frames for DiT). Original FramePack default is 9.')
parser.add_argument('--steps', type=int, default=25, help='Number of inference steps.')
parser.add_argument('--cfg', type=float, default=1.0, help='CFG Scale. If > 1.0, n_prompt is used and gs is set to 1.0. Default 1.0 (for distilled guidance).')
parser.add_argument('--gs', type=float, default=10.0, help='Distilled CFG Scale (Embedded CFG for Original FramePack). Default 10.0.') # Original default
parser.add_argument('--rs', type=float, default=0.0, help='CFG Re-Scale (usually 0.0).')
parser.add_argument('--num_clean_frames', type=int, default=5, help='Number of 1x context frames for DiT conditioning. Script2 default 5.')
# Technical parameters
parser.add_argument('--gpu_memory_preservation', type=float, default=6.0, help='GPU memory to preserve (GB) for low VRAM mode.')
parser.add_argument('--use_teacache', action='store_true', default=False, help='Enable TeaCache (if DiT supports it).')
parser.add_argument('--no_resize', action='store_true', default=False, help='Force original video resolution for input video encoding (VAE).')
parser.add_argument('--mp4_crf', type=int, default=16, help='MP4 CRF value (0-51, lower is better quality).')
parser.add_argument('--vae_batch_size', type=int, default=-1, help='VAE batch size for input video encoding. Default: auto based on VRAM.')
parser.add_argument('--output_dir', type=str, default='./outputs/', help="Directory to save output videos.")
# Model paths
parser.add_argument('--dit', type=str, required=True, help="Path to local DiT model weights file or directory (e.g., for lllyasviel/FramePackI2V_HY).")
parser.add_argument('--vae', type=str, required=True, help="Path to local VAE model weights file or directory.")
parser.add_argument('--text_encoder1', type=str, required=True, help="Path to Text Encoder 1 (Llama) WEIGHT FILE.")
parser.add_argument('--text_encoder2', type=str, required=True, help="Path to Text Encoder 2 (CLIP) WEIGHT FILE.")
parser.add_argument('--image_encoder', type=str, required=True, help="Path to Image Encoder (SigLIP) WEIGHT FILE.")
# Advanced model settings
parser.add_argument('--attn_mode', type=str, default="torch", help="Attention mode for DiT (torch, flash, xformers, etc.).")
parser.add_argument('--fp8_llm', action='store_true', help="Use fp8 for Text Encoder 1 (Llama).") # from fpack_generate_video
parser.add_argument("--vae_chunk_size", type=int, default=None, help="Chunk size for CausalConv3d in VAE.")
parser.add_argument("--vae_spatial_tile_sample_min_size", type=int, default=None, help="Spatial tile sample min size for VAE.")
# LoRA
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path(s).")
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier(s).")
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns.")
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns.")
parser.add_argument('--extension_only', action='store_true', help="Save only the extension video without the input video attached.")
args = parser.parse_args()
current_device_str = str(gpu if torch.cuda.is_available() else cpu)
args.device = current_device_str
for model_arg_name in ['dit', 'vae', 'text_encoder1', 'text_encoder2', 'image_encoder']:
path_val = getattr(args, model_arg_name)
if not os.path.exists(path_val):
parser.error(f"Path for --{model_arg_name} not found: {path_val}")
outputs_folder = args.output_dir
os.makedirs(outputs_folder, exist_ok=True)
print(f"Outputting videos to: {outputs_folder}")
free_mem_gb = get_cuda_free_memory_gb(gpu if torch.cuda.is_available() else None)
# Adjusted high_vram threshold, can be tuned
high_vram = free_mem_gb > 30 # Example: 30GB+ for "high_vram"
print(f'Free VRAM {free_mem_gb:.2f} GB. High-VRAM Mode: {high_vram}')
if args.vae_batch_size == -1:
if free_mem_gb >= 18: args.vae_batch_size = 64
elif free_mem_gb >= 10: args.vae_batch_size = 32
else: args.vae_batch_size = 16
print(f"Auto-set VAE batch size to: {args.vae_batch_size}")
print("Loading models...")
loading_device_str = str(cpu) # Load to CPU first
transformer = load_packed_model(
device=loading_device_str,
dit_path=args.dit,
attn_mode=args.attn_mode,
loading_device=loading_device_str
)
print("DiT loaded.")
if args.lora_weight is not None and len(args.lora_weight) > 0:
print("Merging LoRA weights...")
if len(args.lora_multiplier) == 1 and len(args.lora_weight) > 1:
args.lora_multiplier = args.lora_multiplier * len(args.lora_weight)
elif len(args.lora_multiplier) != len(args.lora_weight):
parser.error(f"Number of LoRA weights ({len(args.lora_weight)}) and multipliers ({len(args.lora_multiplier)}) must match, or provide a single multiplier.")
try:
# Mimic fpack_generate_video.py's LoRA args structure if needed by merge_lora_weights
if not hasattr(args, 'lycoris'): args.lycoris = False
if not hasattr(args, 'save_merged_model'): args.save_merged_model = None
current_device_for_lora = torch.device(loading_device_str)
merge_lora_weights(lora_framepack, transformer, args, current_device_for_lora)
print("LoRA weights merged successfully.")
except Exception as e_lora:
print(f"Error merging LoRA weights: {e_lora}")
traceback.print_exc()
vae = load_vae(
vae_path=args.vae,
vae_chunk_size=args.vae_chunk_size,
vae_spatial_tile_sample_min_size=args.vae_spatial_tile_sample_min_size,
device=loading_device_str
)
print("VAE loaded.")
# For text_encoder loading, fpack_generate_video.py uses args.fp8_llm for text_encoder1
# The f1_video_cli_local.py passes `args` directly. We'll do the same.
tokenizer, text_encoder = load_text_encoder1(args, device=loading_device_str)
print("Text Encoder 1 and Tokenizer 1 loaded.")
tokenizer_2, text_encoder_2 = load_text_encoder2(args)
print("Text Encoder 2 and Tokenizer 2 loaded.")
feature_extractor, image_encoder = load_image_encoders(args)
print("Image Encoder and Feature Extractor loaded.")
all_models_list = [transformer, vae, text_encoder, text_encoder_2, image_encoder]
for model_obj in all_models_list:
if model_obj is not None:
model_obj.eval().requires_grad_(False)
# Set dtypes (Original FramePack typically bfloat16 for DiT, float16 for others)
if transformer: transformer.to(dtype=torch.bfloat16)
if vae: vae.to(dtype=torch.float16)
if image_encoder: image_encoder.to(dtype=torch.float16)
if text_encoder: text_encoder.to(dtype=torch.float16) # Or bfloat16 if fp8_llm implies that
if text_encoder_2: text_encoder_2.to(dtype=torch.float16)
if transformer:
transformer.high_quality_fp32_output_for_inference = True # Common setting
print('Transformer: high_quality_fp32_output_for_inference = True')
if vae and not high_vram:
vae.enable_slicing()
vae.enable_tiling()
target_gpu_device_str = str(gpu if torch.cuda.is_available() else cpu)
if not high_vram and torch.cuda.is_available():
print("Low VRAM mode: Setting up dynamic swapping for DiT and Text Encoder 1.")
if transformer: DynamicSwapInstaller.install_model(transformer, device=target_gpu_device_str)
if text_encoder: DynamicSwapInstaller.install_model(text_encoder, device=target_gpu_device_str)
# Other models (VAE, TE2, ImgEnc) will be loaded/offloaded as needed by `load_model_as_complete` / `unload_complete_models`
if vae: vae.to(cpu)
if text_encoder_2: text_encoder_2.to(cpu)
if image_encoder: image_encoder.to(cpu)
elif torch.cuda.is_available():
print(f"High VRAM mode: Moving all models to {target_gpu_device_str}.")
for model_obj in all_models_list:
if model_obj is not None: model_obj.to(target_gpu_device_str)
else:
print("Running on CPU. Models remain on CPU.")
print("All models loaded and configured.")
# Adjust gs if cfg > 1.0 (standard CFG mode)
actual_gs_cli = args.gs
if args.cfg > 1.0:
actual_gs_cli = 1.0 # For standard CFG, distilled guidance is turned off
print(f"CFG > 1.0 detected ({args.cfg}), this implies standard CFG. Overriding GS to 1.0 from {args.gs}.")
do_generation_work(
input_video_path=args.input_video,
prompt=args.prompt,
n_prompt=args.n_prompt,
seed=args.seed,
end_frame_path=args.end_frame,
end_frame_weight=args.end_frame_weight,
resolution_max_dim=args.resolution_max_dim,
additional_second_length=args.total_second_length,
latent_window_size=args.latent_window_size,
steps=args.steps,
cfg=args.cfg,
gs=actual_gs_cli,
rs=args.rs,
gpu_memory_preservation=args.gpu_memory_preservation,
use_teacache=args.use_teacache,
no_resize=args.no_resize,
mp4_crf=args.mp4_crf,
num_clean_frames=args.num_clean_frames,
vae_batch_size=args.vae_batch_size,
extension_only=args.extension_only
)
print("Video generation process completed.")