import os
import torch
import traceback
import einops
import numpy as np
import argparse
import math
import decord
from tqdm import tqdm
import pathlib
from datetime import datetime
import imageio_ffmpeg
import tempfile
import shutil
import subprocess
import sys
from PIL import Image
try:
    from frame_pack.hunyuan_video_packed import load_packed_model
    from frame_pack.framepack_utils import (
        load_vae,
        load_text_encoder1,
        load_text_encoder2,
        load_image_encoders
    )
    from frame_pack.hunyuan import encode_prompt_conds, vae_decode, vae_encode  # vae_decode_fake might be needed if previews are added
    from frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
    from frame_pack.k_diffusion_hunyuan import sample_hunyuan
    from frame_pack.clip_vision import hf_clip_vision_encode
    from frame_pack.bucket_tools import find_nearest_bucket
    from diffusers_helper.utils import save_bcthw_as_mp4  # from a common helper library
    from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, \
        move_model_to_device_with_memory_preservation, \
        offload_model_from_device_for_memory_preservation, \
        fake_diffusers_current_device, DynamicSwapInstaller, \
        unload_complete_models, load_model_as_complete
    # For LoRA
    from networks import lora_framepack
    try:
        from lycoris.kohya import create_network_from_weights
    except ImportError:
        pass  # Lycoris is optional
    from base_wan_generate_video import merge_lora_weights  # Assuming this is accessible
except ImportError as e:
    print(f"Error importing FramePack related modules: {e}. Ensure they are in PYTHONPATH.")
    sys.exit(1)
# --- Global Model Variables ---
text_encoder = None
text_encoder_2 = None
tokenizer = None
tokenizer_2 = None
vae = None
feature_extractor = None
image_encoder = None
transformer = None
high_vram = False
free_mem_gb = 0.0
outputs_folder = './outputs/'  # Default, can be overridden by --output_dir
def video_encode(video_path, resolution, no_resize, vae_model, vae_batch_size=16, device="cuda", width=None, height=None):
    video_path = str(pathlib.Path(video_path).resolve())
    print(f"Processing video for encoding: {video_path}")

    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available, falling back to CPU for video_encode")
        device = "cpu"

    try:
        print("Initializing VideoReader...")
        vr = decord.VideoReader(video_path)
        fps = vr.get_avg_fps()
        if fps == 0:
            print("Warning: VideoReader reported FPS as 0. Attempting to get it via OpenCV.")
            import cv2
            cap = cv2.VideoCapture(video_path)
            fps_cv = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            if fps_cv > 0:
                fps = fps_cv
                print(f"Using FPS from OpenCV: {fps}")
            else:
                # Fallback FPS if all else fails
                fps = 25
                print(f"Failed to determine FPS for the input video. Defaulting to {fps} FPS.")

        num_real_frames = len(vr)
        print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")

        latent_size_factor = 4  # the HunyuanVideo VAE compresses time in groups of 4 pixel frames (space by 8x), so trim to a multiple of 4
        num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
        if num_frames != num_real_frames:
            print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility (multiple of {latent_size_factor})")
        if num_frames == 0:
            raise ValueError(f"Video too short ({num_real_frames} frames) or becomes 0 after truncation. Needs at least {latent_size_factor} frames.")
        num_real_frames = num_frames

        print("Reading video frames...")
        frames_np_all = vr.get_batch(range(num_real_frames)).asnumpy()
        print(f"Frames read: {frames_np_all.shape}")

        native_height, native_width = frames_np_all.shape[1], frames_np_all.shape[2]
        print(f"Native video resolution: {native_width}x{native_height}")
        target_h_arg = native_height if height is None else height
        target_w_arg = native_width if width is None else width

        if not no_resize:
            actual_target_height, actual_target_width = find_nearest_bucket(target_h_arg, target_w_arg, resolution=resolution)
            print(f"Adjusted resolution for VAE encoding: {actual_target_width}x{actual_target_height}")
        else:
            actual_target_width = (native_width // 8) * 8
            actual_target_height = (native_height // 8) * 8
            if actual_target_width != native_width or actual_target_height != native_height:
                print(f"Using native resolution, adjusted to be divisible by 8: {actual_target_width}x{actual_target_height}")
            else:
                print(f"Using native resolution without resizing: {actual_target_width}x{actual_target_height}")

        processed_frames_list = []
        for frame_idx in range(frames_np_all.shape[0]):
            frame = frames_np_all[frame_idx]
            frame_resized_np = resize_and_center_crop(frame, target_width=actual_target_width, target_height=actual_target_height)
            processed_frames_list.append(frame_resized_np)
        processed_frames_np_stack = np.stack(processed_frames_list)
        print(f"Frames preprocessed: {processed_frames_np_stack.shape}")

        input_image_np_for_clip_first = processed_frames_np_stack[0]
        input_image_np_for_clip_last = processed_frames_np_stack[-1]

        print("Converting frames to tensor...")
        frames_pt = torch.from_numpy(processed_frames_np_stack).float() / 127.5 - 1.0
        frames_pt = frames_pt.permute(0, 3, 1, 2)  # F, H, W, C -> F, C, H, W
        frames_pt = frames_pt.unsqueeze(0).permute(0, 2, 1, 3, 4)  # 1, F, C, H, W -> 1, C, F, H, W (the VAE expects B, C, F, H, W)
        print(f"Tensor shape for VAE: {frames_pt.shape}")

        input_video_pixels_cpu = frames_pt.clone().cpu()

        print(f"Moving VAE and tensor to device: {device}")
        vae_model.to(device)
        frames_pt = frames_pt.to(device)

        print(f"Encoding input video frames with VAE (batch size: {vae_batch_size})")
        all_latents_list = []
        vae_model.eval()
        with torch.no_grad():
            for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="VAE Encoding Video Frames", mininterval=0.1):
                batch_frames_pt = frames_pt[:, :, i:i + vae_batch_size]
                try:
                    batch_latents = vae_encode(batch_frames_pt, vae_model)
                    all_latents_list.append(batch_latents.cpu())
                except RuntimeError as e:
                    print(f"Error during VAE encoding: {str(e)}")
                    if "out of memory" in str(e).lower() and device == "cuda":
                        print("CUDA out of memory during VAE encoding. Try reducing --vae_batch_size or use CPU for VAE.")
                    raise

        history_latents_cpu = torch.cat(all_latents_list, dim=2)
        print(f"History latents shape (original video): {history_latents_cpu.shape}")

        start_latent_cpu = history_latents_cpu[:, :, :1].clone()
        end_of_input_video_latent_cpu = history_latents_cpu[:, :, -1:].clone()
        print(f"Start latent shape (for conditioning): {start_latent_cpu.shape}")
        print(f"End of input video latent shape: {end_of_input_video_latent_cpu.shape}")

        if device == "cuda":
            vae_model.to(cpu)  # Move VAE back to CPU
            torch.cuda.empty_cache()
            print("VAE moved back to CPU, CUDA cache cleared")

        return (start_latent_cpu, input_image_np_for_clip_first,
                history_latents_cpu, fps,
                actual_target_height, actual_target_width,
                input_video_pixels_cpu,
                end_of_input_video_latent_cpu, input_image_np_for_clip_last)
    except Exception as e:
        print(f"Error in video_encode: {str(e)}")
        traceback.print_exc()
        raise
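
# Illustrative usage of video_encode (hypothetical paths and values; the real call site is in
# do_generation_work below):
#   (start_latent, first_frame_np, video_latents, fps, h, w,
#    input_pixels, last_latent, last_frame_np) = video_encode(
#       "input.mp4", resolution=640, no_resize=False, vae_model=vae, vae_batch_size=16, device="cuda")
# All returned tensors live on the CPU; latents are shaped (1, C, F_latent, H // 8, W // 8).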
def image_encode(image_np, target_width, target_height, vae_model, image_encoder_model, feature_extractor_model, device="cuda"):
    """
    Encode a single image into a VAE latent and a CLIP vision embedding.
    """
    global high_vram  # Use global high_vram status
    print("Processing single image for encoding (e.g., end_frame)...")
    try:
        print(f"Using target resolution for image encoding: {target_width}x{target_height}")
        processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)

        image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1.0
        image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2)  # B, C, F, H, W (B=1, F=1)

        target_vae_device = device
        if not high_vram: load_model_as_complete(vae_model, target_device=target_vae_device)
        else: vae_model.to(target_vae_device)
        image_pt_device = image_pt.to(target_vae_device)
        latent = vae_encode(image_pt_device, vae_model).cpu()  # Encode and move to CPU
        print(f"Single image VAE output shape (latent): {latent.shape}")
        if not high_vram: unload_complete_models(vae_model)  # Offload VAE if low VRAM

        target_img_enc_device = device
        if not high_vram: load_model_as_complete(image_encoder_model, target_device=target_img_enc_device)
        else: image_encoder_model.to(target_img_enc_device)
        clip_embedding_output = hf_clip_vision_encode(processed_image_np, feature_extractor_model, image_encoder_model)
        clip_embedding = clip_embedding_output.last_hidden_state.cpu()  # Encode and move to CPU
        print(f"Single image CLIP embedding shape: {clip_embedding.shape}")
        if not high_vram: unload_complete_models(image_encoder_model)  # Offload image encoder if low VRAM

        if device == "cuda":
            torch.cuda.empty_cache()
            # print("CUDA cache cleared after single image_encode")

        return latent, clip_embedding, processed_image_np
    except Exception as e:
        print(f"Error in image_encode: {str(e)}")
        traceback.print_exc()
        raise
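
# Illustrative usage of image_encode (hypothetical file name and sizes; the real call site is in
# do_generation_work below):
#   end_latent, end_clip_emb, end_np = image_encode(
#       np.array(Image.open("end.png").convert("RGB")), target_width=640, target_height=384,
#       vae_model=vae, image_encoder_model=image_encoder, feature_extractor_model=feature_extractor)
# end_latent is a CPU tensor shaped (1, C, 1, H // 8, W // 8); end_clip_emb is the CLIP vision
# last_hidden_state, also on the CPU.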
def set_mp4_comments_imageio_ffmpeg(input_file, comments):
    try:
        ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        if not os.path.exists(input_file):
            print(f"Error: Input file {input_file} does not exist")
            return False
        temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
        command = [
            ffmpeg_path, '-i', input_file, '-metadata', f'comment={comments}',
            '-c:v', 'copy', '-c:a', 'copy', '-y', temp_file
        ]
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
        if result.returncode == 0:
            shutil.move(temp_file, input_file)
            print(f"Successfully added comments to {input_file}")
            return True
        else:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            print(f"Error: FFmpeg failed with message:\n{result.stderr}")
            return False
    except Exception as e:
        if 'temp_file' in locals() and os.path.exists(temp_file):
            os.remove(temp_file)
        print(f"Error saving prompt to video metadata (ffmpeg may be required): {e}")
        return False
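
# The subprocess call above is equivalent to running (illustrative command line):
#   ffmpeg -i <input.mp4> -metadata comment="<comments>" -c:v copy -c:a copy -y <temp.mp4>
# and then moving the temp file over the original. Both streams are copied, so no re-encoding occurs.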
def do_generation_work(
        input_video_path, prompt, n_prompt, seed,
        end_frame_path, end_frame_weight,  # end-frame guidance arguments
        resolution_max_dim,
        additional_second_length,
        latent_window_size, steps, cfg, gs, rs,
        gpu_memory_preservation, use_teacache, no_resize, mp4_crf,
        num_clean_frames, vae_batch_size,
        extension_only
):
    global high_vram, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, feature_extractor, image_encoder, transformer, args

    print('--- Starting Video Generation (with End Frame support) ---')
    try:
        # --- Text Encoding ---
        print('Text encoding...')
        target_text_enc_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram:
            if text_encoder: fake_diffusers_current_device(text_encoder, target_text_enc_device)  # text_encoder is managed by DynamicSwapInstaller
            if text_encoder_2: load_model_as_complete(text_encoder_2, target_device=target_text_enc_device)
        else:
            if text_encoder: text_encoder.to(target_text_enc_device)
            if text_encoder_2: text_encoder_2.to(target_text_enc_device)

        llama_vec_gpu, clip_l_pooler_gpu = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        if cfg == 1.0:  # FramePack default: distilled guidance (gs) is active, so the negative branch is zeroed out
            llama_vec_n_gpu, clip_l_pooler_n_gpu = torch.zeros_like(llama_vec_gpu), torch.zeros_like(clip_l_pooler_gpu)
        else:  # cfg > 1.0 implies standard CFG, so n_prompt is encoded; gs should be 1.0 in this case
            llama_vec_n_gpu, clip_l_pooler_n_gpu = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        # Store on CPU
        llama_vec_padded_cpu, llama_attention_mask_cpu = crop_or_pad_yield_mask(llama_vec_gpu.cpu(), length=512)
        llama_vec_n_padded_cpu, llama_attention_mask_n_cpu = crop_or_pad_yield_mask(llama_vec_n_gpu.cpu(), length=512)
        clip_l_pooler_cpu = clip_l_pooler_gpu.cpu()
        clip_l_pooler_n_cpu = clip_l_pooler_n_gpu.cpu()
        if not high_vram: unload_complete_models(text_encoder_2)  # text_encoder stays under DynamicSwap management

        # --- Video and End Frame Encoding ---
        print('Encoding input video...')
        video_encode_device = str(gpu if torch.cuda.is_available() else cpu)
        (start_latent_input_cpu, input_image_np_first,
         video_latents_history_cpu, fps, height, width,
         input_video_pixels_cpu,
         end_of_input_video_latent_cpu, input_image_np_last) = video_encode(
            input_video_path, resolution_max_dim, no_resize, vae,
            vae_batch_size=vae_batch_size, device=video_encode_device,
            width=None, height=None  # video_encode derives the bucket size from resolution_max_dim
        )
        if fps <= 0:
            raise ValueError("FPS from input video is 0 or invalid.")
        end_latent_from_file_cpu, end_clip_embedding_from_file_cpu = None, None
        if end_frame_path:
            print(f"Encoding provided end frame from: {end_frame_path}")
            end_frame_pil = Image.open(end_frame_path).convert("RGB")
            end_frame_np = np.array(end_frame_pil)
            end_latent_from_file_cpu, end_clip_embedding_from_file_cpu, _ = image_encode(
                end_frame_np, target_width=width, target_height=height,
                vae_model=vae, image_encoder_model=image_encoder,
                feature_extractor_model=feature_extractor, device=video_encode_device
            )

        # --- CLIP Vision Encoding for the input video's boundary frames ---
        print('CLIP Vision encoding for input video frames...')
        target_img_enc_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram: load_model_as_complete(image_encoder, target_device=target_img_enc_device)
        else: image_encoder.to(target_img_enc_device)
        # In the original FramePack, the image embedding passed to sample_hunyuan usually comes from the *start* image.
        # Script 2 (demo_gradio_video+endframe.py) instead uses the embedding of the input video's last frame,
        # optionally blended with the explicit end_frame embedding; we follow script 2 for conditioning.
        # start_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_first, feature_extractor, image_encoder).last_hidden_state.cpu()
        end_of_input_video_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_last, feature_extractor, image_encoder).last_hidden_state.cpu()
        if not high_vram: unload_complete_models(image_encoder)

        # Determine the image embedding used throughout the sampling loop
        if end_clip_embedding_from_file_cpu is not None:
            print(f"Blending end-of-input-video embedding with provided end_frame embedding (weight: {end_frame_weight})")
            final_clip_embedding_for_sampling_cpu = \
                (1.0 - end_frame_weight) * end_of_input_video_clip_embedding_cpu + \
                end_frame_weight * end_clip_embedding_from_file_cpu
        else:
            print("Using the input video's last-frame embedding for image conditioning.")
            final_clip_embedding_for_sampling_cpu = end_of_input_video_clip_embedding_cpu.clone()
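
        # Illustrative values for the blend above: with the default end_frame_weight=1.0 the result is just
        # the end_frame embedding; with end_frame_weight=0.3 it would be
        # 0.7 * last-frame embedding + 0.3 * end_frame embedding.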
        # --- Prepare for Sampling Loop ---
        target_transformer_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram:
            if transformer: move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
        else:
            if transformer: transformer.to(target_transformer_device)

        cond_device = transformer.device
        cond_dtype = transformer.dtype

        # Move conditioning tensors to the transformer's device and dtype
        llama_vec = llama_vec_padded_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_attention_mask = llama_attention_mask_cpu.to(device=cond_device)  # mask stays bool/int
        clip_l_pooler = clip_l_pooler_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_vec_n = llama_vec_n_padded_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_attention_mask_n = llama_attention_mask_n_cpu.to(device=cond_device)
        clip_l_pooler_n = clip_l_pooler_n_cpu.to(device=cond_device, dtype=cond_dtype)

        # This is the image embedding used in the sampling loop
        image_embeddings_for_sampling_loop = final_clip_embedding_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype)

        # Note on conditioning latents: start_latent_input_cpu (first frame of the input video) could serve as
        # clean_latents_pre, but script 2 takes clean_latents_pre from the *end* of the input video
        # (`video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):]`). For backward generation,
        # the "start latent" concatenated into sample_hunyuan's `clean_latents` argument is effectively the last
        # frame of the input video when generating the chunk closest to it, so end_of_input_video_latent_cpu
        # plays that role below.

        num_output_pixel_frames_per_section = latent_window_size * 4  # used for the total section count (not the "* 4 - 3" per-chunk value)
        if num_output_pixel_frames_per_section == 0:
            raise ValueError("latent_window_size * 4 is zero, cannot calculate total_extension_latent_sections.")
        total_extension_latent_sections = int(max(round((additional_second_length * fps) / num_output_pixel_frames_per_section), 1))
        print(f"Input video FPS: {fps}, Target additional length: {additional_second_length}s")
        print(f"Generating {total_extension_latent_sections} new sections for extension (approx {total_extension_latent_sections * num_output_pixel_frames_per_section / fps:.2f}s).")
        job_id_base = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + \
            f"_framepack-vidEndFrm_{width}x{height}_{additional_second_length:.1f}s_seed{seed}_s{steps}_gs{gs}_cfg{cfg}"
        job_id = job_id_base
        if extension_only:
            job_id += "_extonly"
            print("Extension-only mode enabled. Filenames will reflect this.")

        rnd = torch.Generator("cpu").manual_seed(seed)

        # Latent geometry, needed both for the empty history below and for the zero-filled fallbacks later
        channels_dim = video_latents_history_cpu.shape[1]
        latent_h, latent_w = height // 8, width // 8

        # Initialize the history of generated latents (starts empty, or seeded with the end_frame latent)
        if end_latent_from_file_cpu is not None:
            # end_latent_from_file_cpu is [1, C, 1, H', W']; script 2 uses just this single frame for
            # clean_latents_post when is_end_of_video.
            history_latents_generated_cpu = end_latent_from_file_cpu.clone()
        else:
            history_latents_generated_cpu = torch.empty((1, channels_dim, 0, latent_h, latent_w), dtype=torch.float32, device='cpu')

        # Initialize history for decoded pixels (starts empty)
        history_pixels_decoded_cpu = None
        total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]
        previous_video_path_for_cleanup = None

        # Backward generation loop (from script 2, demo_gradio_video+endframe.py)
        latent_paddings = list(reversed(range(total_extension_latent_sections)))
        if total_extension_latent_sections > 4:  # heuristic from script 2
            latent_paddings = [3] + [2] * (total_extension_latent_sections - 3) + [1, 0]
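
        # Illustrative padding schedules produced by the two branches above: 3 sections -> [2, 1, 0];
        # 6 sections -> [3, 2, 2, 2, 1, 0]. Padding 0 always marks the chunk that touches the input video.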
        for loop_idx, latent_padding_val in enumerate(latent_paddings):
            current_section_num_from_end = loop_idx + 1
            is_start_of_extension = (latent_padding_val == 0)  # the chunk closest to the input video (generated last)
            is_end_of_extension = (latent_padding_val == latent_paddings[0])  # the chunk furthest from the input video (generated first)

            print(f"--- Generating Extension: Seed {seed}: Section {current_section_num_from_end}/{total_extension_latent_sections} (backward), padding={latent_padding_val} ---")

            if transformer: transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps if use_teacache else 0)

            progress_bar_sampler = tqdm(total=steps, desc=f"Sampling Extension Section {current_section_num_from_end}/{total_extension_latent_sections}", file=sys.stdout, dynamic_ncols=True)

            def sampler_callback_cli(d):
                progress_bar_sampler.update(1)

            # Context frame calculation (from script 2's worker):
            # clean_latents_pre always comes from video_latents_history_cpu (the input video), while
            # clean_latents_post, _2x and _4x come from history_latents_generated_cpu (the extension generated so far).
            effective_clean_frames_count = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1

            # clean_latents_pre (from the input video). For the chunk closest to the input video,
            # script 2 uses a single frame to avoid jump cuts at the transition.
            clean_latent_pre_frames_num = effective_clean_frames_count
            if is_start_of_extension:
                clean_latent_pre_frames_num = 1

            # clean_latents_post, _2x, _4x (from previously generated extension chunks)
            available_generated_latents = history_latents_generated_cpu.shape[2]
            # post_frames_num drives clean_latents_post; with an explicit end frame, only that single frame anchors the far end
            post_frames_num = 1 if (is_end_of_extension and end_latent_from_file_cpu is not None) else effective_clean_frames_count
            num_2x_frames_count = min(2, max(0, available_generated_latents - post_frames_num - 1))
            num_4x_frames_count = min(16, max(0, available_generated_latents - post_frames_num - num_2x_frames_count))

            # Latent indexing for sample_hunyuan (from script 2)
            latent_padding_size_for_indices = latent_padding_val * latent_window_size
            pixel_frames_to_generate_this_step = latent_window_size * 4 - 3

            indices_tensor_gpu = torch.arange(0,
                                              clean_latent_pre_frames_num +
                                              latent_padding_size_for_indices +
                                              latent_window_size +  # script 2 uses latent_window_size here for the `latent_indices` count
                                              post_frames_num +
                                              num_2x_frames_count +
                                              num_4x_frames_count
                                              ).unsqueeze(0).to(cond_device)
            (clean_latent_indices_pre_gpu,
             blank_indices_gpu,  # padding
             latent_indices_for_denoising_gpu,  # new generation
             clean_latent_indices_post_gpu,
             clean_latent_2x_indices_gpu,
             clean_latent_4x_indices_gpu
             ) = indices_tensor_gpu.split(
                [clean_latent_pre_frames_num, latent_padding_size_for_indices, latent_window_size,
                 post_frames_num, num_2x_frames_count, num_4x_frames_count], dim=1
            )
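
            # Illustrative index layout (assumed values: latent_window_size=9, num_clean_frames=5 so the
            # effective 1x count is 4, latent_padding_val=2, and enough generated history for 2x=2 and 4x=16):
            # the arange covers 4 + 18 + 9 + 4 + 2 + 16 = 53 positions, split into [0..3] pre, [4..21] blank
            # padding, [22..30] denoised window, [31..34] post, [35..36] 2x and [37..52] 4x context.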
            clean_latent_indices_combined_gpu = torch.cat([clean_latent_indices_pre_gpu, clean_latent_indices_post_gpu], dim=1)

            # Prepare conditioning latents
            # clean_latents_pre_cpu: taken from the end of the input video
            actual_pre_frames_to_take = min(clean_latent_pre_frames_num, video_latents_history_cpu.shape[2])
            clean_latents_pre_cpu = video_latents_history_cpu[:, :, -actual_pre_frames_to_take:].clone()
            if 0 < clean_latents_pre_cpu.shape[2] < clean_latent_pre_frames_num:  # pad by repetition if necessary
                repeats = math.ceil(clean_latent_pre_frames_num / clean_latents_pre_cpu.shape[2])
                clean_latents_pre_cpu = clean_latents_pre_cpu.repeat(1, 1, repeats, 1, 1)[:, :, :clean_latent_pre_frames_num]
            elif clean_latents_pre_cpu.shape[2] == 0 and clean_latent_pre_frames_num > 0:  # should not happen if video_latents_history_cpu is valid
                clean_latents_pre_cpu = torch.zeros((1, channels_dim, clean_latent_pre_frames_num, latent_h, latent_w), dtype=torch.float32)

            # clean_latents_post_cpu, _2x_cpu, _4x_cpu: taken from the start of history_latents_generated_cpu
            current_offset_in_generated = 0

            # Post frames
            actual_post_frames_to_take = min(post_frames_num, history_latents_generated_cpu.shape[2])
            if is_end_of_extension and end_latent_from_file_cpu is not None:
                clean_latents_post_cpu = end_latent_from_file_cpu.clone()  # [1, C, 1, H', W']
            else:
                clean_latents_post_cpu = history_latents_generated_cpu[:, :, current_offset_in_generated:current_offset_in_generated + actual_post_frames_to_take].clone()
            current_offset_in_generated += clean_latents_post_cpu.shape[2]
            if 0 < clean_latents_post_cpu.shape[2] < post_frames_num:  # pad
                repeats = math.ceil(post_frames_num / clean_latents_post_cpu.shape[2])
                clean_latents_post_cpu = clean_latents_post_cpu.repeat(1, 1, repeats, 1, 1)[:, :, :post_frames_num]
            elif clean_latents_post_cpu.shape[2] == 0 and post_frames_num > 0:  # fill with zeros if no history and no end_latent
                clean_latents_post_cpu = torch.zeros((1, channels_dim, post_frames_num, latent_h, latent_w), dtype=torch.float32)

            # 2x frames
            actual_2x_frames_to_take = min(num_2x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
            clean_latents_2x_cpu = history_latents_generated_cpu[:, :, current_offset_in_generated:current_offset_in_generated + actual_2x_frames_to_take].clone()
            current_offset_in_generated += clean_latents_2x_cpu.shape[2]
            if 0 < clean_latents_2x_cpu.shape[2] < num_2x_frames_count:  # pad
                repeats = math.ceil(num_2x_frames_count / clean_latents_2x_cpu.shape[2])
                clean_latents_2x_cpu = clean_latents_2x_cpu.repeat(1, 1, repeats, 1, 1)[:, :, :num_2x_frames_count]
            elif clean_latents_2x_cpu.shape[2] == 0 and num_2x_frames_count > 0:
                clean_latents_2x_cpu = torch.zeros((1, channels_dim, num_2x_frames_count, latent_h, latent_w), dtype=torch.float32)

            # 4x frames
            actual_4x_frames_to_take = min(num_4x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
            clean_latents_4x_cpu = history_latents_generated_cpu[:, :, current_offset_in_generated:current_offset_in_generated + actual_4x_frames_to_take].clone()
            if 0 < clean_latents_4x_cpu.shape[2] < num_4x_frames_count:  # pad
                repeats = math.ceil(num_4x_frames_count / clean_latents_4x_cpu.shape[2])
                clean_latents_4x_cpu = clean_latents_4x_cpu.repeat(1, 1, repeats, 1, 1)[:, :, :num_4x_frames_count]
            elif clean_latents_4x_cpu.shape[2] == 0 and num_4x_frames_count > 0:
                clean_latents_4x_cpu = torch.zeros((1, channels_dim, num_4x_frames_count, latent_h, latent_w), dtype=torch.float32)

            # Combine pre and post for the `clean_latents` argument
            clean_latents_for_sampler_gpu = torch.cat([
                clean_latents_pre_cpu.to(device=cond_device, dtype=torch.float32),
                clean_latents_post_cpu.to(device=cond_device, dtype=torch.float32)
            ], dim=2)

            # Pass None for the 2x/4x latents if their frame counts are 0.
            # k_diffusion_hunyuan.sample_hunyuan and the DiT should handle None when the corresponding indices are empty.
            clean_latents_2x_gpu = None
            if num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] > 0:
                clean_latents_2x_gpu = clean_latents_2x_cpu.to(device=cond_device, dtype=torch.float32)
            elif num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] == 0:  # should have been zero-filled above
                print(f"Warning: num_2x_frames_count is {num_2x_frames_count} but clean_latents_2x_cpu is empty. Defaulting to None.")

            clean_latents_4x_gpu = None
            if num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] > 0:
                clean_latents_4x_gpu = clean_latents_4x_cpu.to(device=cond_device, dtype=torch.float32)
            elif num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] == 0:
                print(f"Warning: num_4x_frames_count is {num_4x_frames_count} but clean_latents_4x_cpu is empty. Defaulting to None.")

            # Indices are likewise empty when their counts are zero: the split above already yields
            # (B, 0)-shaped tensors for a split size of 0 (e.g., clean_latent_2x_indices_gpu when
            # num_2x_frames_count is 0), and the DiT should interpret an empty indices tensor or None
            # for the corresponding latent.
            generated_latents_gpu_step = sample_hunyuan(
                transformer=transformer, sampler='unipc', width=width, height=height,
                frames=pixel_frames_to_generate_this_step,  # number of pixel frames for the current chunk
                real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs,
                num_inference_steps=steps, generator=rnd,
                prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n,
                device=cond_device, dtype=cond_dtype,
                image_embeddings=image_embeddings_for_sampling_loop,  # the blended/final embedding
                latent_indices=latent_indices_for_denoising_gpu,
                clean_latents=clean_latents_for_sampler_gpu,
                clean_latent_indices=clean_latent_indices_combined_gpu,
                clean_latents_2x=clean_latents_2x_gpu,  # can be None
                clean_latent_2x_indices=clean_latent_2x_indices_gpu if num_2x_frames_count > 0 else None,
                clean_latents_4x=clean_latents_4x_gpu,  # can be None
                clean_latent_4x_indices=clean_latent_4x_indices_gpu if num_4x_frames_count > 0 else None,
                callback=sampler_callback_cli,
            )
            if progress_bar_sampler: progress_bar_sampler.close()

            # For the chunk closest to the input video, prepend the last frame of the input video for a smoother transition
            if is_start_of_extension:
                generated_latents_gpu_step = torch.cat([
                    end_of_input_video_latent_cpu.to(generated_latents_gpu_step),  # actual last-frame latent of the input video
                    generated_latents_gpu_step
                ], dim=2)

            # Prepend generated latents to the history (backward generation grows the history at the front)
            history_latents_generated_cpu = torch.cat([generated_latents_gpu_step.cpu(), history_latents_generated_cpu], dim=2)
            total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]

            # --- Decode and Append Pixels ---
            target_vae_device = str(gpu if torch.cuda.is_available() else cpu)
            if not high_vram:
                if transformer: offload_model_from_device_for_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
                if vae: load_model_as_complete(vae, target_device=target_vae_device)
            else:
                if vae: vae.to(target_vae_device)

            # Decode the newly generated part plus enough overlap for stitching.
            # Script 2 decodes real_history_latents[:, :, :section_latent_frames], where
            # section_latent_frames = latent_window_size * 2 + 1 if is_start_of_video else latent_window_size * 2.
            num_latents_to_decode_for_stitch = (latent_window_size * 2 + 1) if is_start_of_extension else (latent_window_size * 2)
            num_latents_to_decode_for_stitch = min(num_latents_to_decode_for_stitch, history_latents_generated_cpu.shape[2])
            latents_for_current_decode_gpu = history_latents_generated_cpu[:, :, :num_latents_to_decode_for_stitch].to(target_vae_device)
            pixels_for_current_part_decoded_cpu = vae_decode(latents_for_current_decode_gpu, vae).cpu()

            # Soft append pixels: soft_append_bcthw(current_pixels, history_pixels, overlap)
            overlap_for_soft_append = latent_window_size * 4 - 3
            if history_pixels_decoded_cpu is None:
                history_pixels_decoded_cpu = pixels_for_current_part_decoded_cpu
            else:
                overlap_actual = min(overlap_for_soft_append, history_pixels_decoded_cpu.shape[2], pixels_for_current_part_decoded_cpu.shape[2])
                if overlap_actual <= 0:  # should not happen with proper windowing
                    history_pixels_decoded_cpu = torch.cat([pixels_for_current_part_decoded_cpu, history_pixels_decoded_cpu], dim=2)  # simple prepend
                else:
                    history_pixels_decoded_cpu = soft_append_bcthw(
                        pixels_for_current_part_decoded_cpu,  # current (prepended)
                        history_pixels_decoded_cpu,  # history
                        overlap=overlap_actual
                    )
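
            # With the default latent_window_size=9, overlap_for_soft_append is 9 * 4 - 3 = 33 pixel frames,
            # so each stitch cross-fades roughly one second of footage at 30 fps (illustrative figure).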
            if not high_vram:
                if vae: unload_complete_models(vae)
                if transformer and not is_start_of_extension:  # reload the transformer for the next iteration
                    move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)

            # Save intermediate video
            current_output_filename = os.path.join(outputs_folder, f'{job_id}_part{current_section_num_from_end}_totalframes{history_pixels_decoded_cpu.shape[2]}.mp4')
            save_bcthw_as_mp4(history_pixels_decoded_cpu, current_output_filename, fps=fps, crf=mp4_crf)
            print(f"MP4 preview for section {current_section_num_from_end} saved: {current_output_filename}")
            set_mp4_comments_imageio_ffmpeg(current_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")

            if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup):
                try:
                    os.remove(previous_video_path_for_cleanup)
                except Exception as e_del:
                    print(f"Error deleting {previous_video_path_for_cleanup}: {e_del}")
            previous_video_path_for_cleanup = current_output_filename

            if is_start_of_extension:  # last iteration of the backward loop
                break
        # --- Final Video Assembly ---
        if extension_only:
            print("Saving only the generated extension...")
            # With backward generation, history_pixels_decoded_cpu accumulates only the generated extension.
            video_to_save_cpu = history_pixels_decoded_cpu
            final_output_filename_suffix = "_extension_only_final.mp4"
            final_log_message = "Final extension-only video saved:"
        else:
            print("Appending generated extension to the input video...")
            # input_video_pixels_cpu is (1, C, F_in, H, W); history_pixels_decoded_cpu is (1, C, F_ext, H, W)
            video_to_save_cpu = torch.cat([input_video_pixels_cpu, history_pixels_decoded_cpu], dim=2)
            final_output_filename_suffix = "_final.mp4"
            final_log_message = "Final extended video saved:"

        final_output_filename = os.path.join(outputs_folder, f'{job_id}{final_output_filename_suffix}')  # job_id already carries _extonly if needed
        save_bcthw_as_mp4(video_to_save_cpu, final_output_filename, fps=fps, crf=mp4_crf)
        print(f"{final_log_message} {final_output_filename}")
        set_mp4_comments_imageio_ffmpeg(final_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")

        if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup) and previous_video_path_for_cleanup != final_output_filename:
            try:
                os.remove(previous_video_path_for_cleanup)
            except Exception as e_del:
                print(f"Error deleting last part: {e_del}")
    except Exception as e_outer:
        traceback.print_exc()
        print(f"Error during generation: {e_outer}")
    finally:
        if not high_vram:
            unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
    print("--- Generation work cycle finished. ---")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="FramePack Video Generation CLI (with End Frame)")
    # Inputs
    parser.add_argument('--input_video', type=str, required=True, help='Path to the input video file.')
    parser.add_argument('--prompt', type=str, required=True, help='Prompt for video generation.')
    parser.add_argument('--n_prompt', type=str, default="", help='Negative prompt.')
    parser.add_argument('--end_frame', type=str, default=None, help='Optional path to an image to guide the end of the video.')
    parser.add_argument('--end_frame_weight', type=float, default=1.0, help='Weight for the end_frame image conditioning (0.0 to 1.0). Default 1.0.')
    # Generation parameters
    parser.add_argument('--seed', type=int, default=31337, help='Seed for generation.')
    parser.add_argument('--resolution_max_dim', type=int, default=640, help='Target resolution (max width or height for bucket search).')
    parser.add_argument('--total_second_length', type=float, default=5.0, help='Additional video length to generate (seconds).')
    parser.add_argument('--latent_window_size', type=int, default=9, help='Latent window size (frames for DiT). Original FramePack default is 9.')
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps.')
    parser.add_argument('--cfg', type=float, default=1.0, help='CFG scale. If > 1.0, n_prompt is used and gs is set to 1.0. Default 1.0 (for distilled guidance).')
    parser.add_argument('--gs', type=float, default=10.0, help='Distilled CFG scale (embedded CFG for original FramePack). Default 10.0.')
    parser.add_argument('--rs', type=float, default=0.0, help='CFG re-scale (usually 0.0).')
    parser.add_argument('--num_clean_frames', type=int, default=5, help='Number of 1x context frames for DiT conditioning. Script 2 default is 5.')
    # Technical parameters
    parser.add_argument('--gpu_memory_preservation', type=float, default=6.0, help='GPU memory to preserve (GB) for low VRAM mode.')
    parser.add_argument('--use_teacache', action='store_true', default=False, help='Enable TeaCache (if the DiT supports it).')
    parser.add_argument('--no_resize', action='store_true', default=False, help='Force original video resolution for input video encoding (VAE).')
    parser.add_argument('--mp4_crf', type=int, default=16, help='MP4 CRF value (0-51, lower is better quality).')
    parser.add_argument('--vae_batch_size', type=int, default=-1, help='VAE batch size for input video encoding. Default: auto based on VRAM.')
    parser.add_argument('--output_dir', type=str, default='./outputs/', help="Directory to save output videos.")
    # Model paths
    parser.add_argument('--dit', type=str, required=True, help="Path to local DiT model weights file or directory (e.g., for lllyasviel/FramePackI2V_HY).")
    parser.add_argument('--vae', type=str, required=True, help="Path to local VAE model weights file or directory.")
    parser.add_argument('--text_encoder1', type=str, required=True, help="Path to the Text Encoder 1 (Llama) weight file.")
    parser.add_argument('--text_encoder2', type=str, required=True, help="Path to the Text Encoder 2 (CLIP) weight file.")
    parser.add_argument('--image_encoder', type=str, required=True, help="Path to the Image Encoder (SigLIP) weight file.")
    # Advanced model settings
    parser.add_argument('--attn_mode', type=str, default="torch", help="Attention mode for DiT (torch, flash, xformers, etc.).")
    parser.add_argument('--fp8_llm', action='store_true', help="Use fp8 for Text Encoder 1 (Llama).")  # from fpack_generate_video
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="Chunk size for CausalConv3d in VAE.")
    parser.add_argument("--vae_spatial_tile_sample_min_size", type=int, default=None, help="Spatial tile sample min size for VAE.")
    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path(s).")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier(s).")
    parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns.")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns.")
    parser.add_argument('--extension_only', action='store_true', help="Save only the extension video without the input video attached.")

    args = parser.parse_args()
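
    # Illustrative invocation (the script name and all paths below are placeholders, not shipped files):
    #   python <this_script>.py \
    #       --input_video input.mp4 --prompt "the camera slowly zooms out" \
    #       --dit <path-to-dit> --vae <path-to-vae> --text_encoder1 <path-to-llama-te> \
    #       --text_encoder2 <path-to-clip-te> --image_encoder <path-to-siglip> \
    #       --total_second_length 5.0 --end_frame end.png --end_frame_weight 1.0 --extension_only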
    current_device_str = str(gpu if torch.cuda.is_available() else cpu)
    args.device = current_device_str

    for model_arg_name in ['dit', 'vae', 'text_encoder1', 'text_encoder2', 'image_encoder']:
        path_val = getattr(args, model_arg_name)
        if not os.path.exists(path_val):
            parser.error(f"Path for --{model_arg_name} not found: {path_val}")

    outputs_folder = args.output_dir
    os.makedirs(outputs_folder, exist_ok=True)
    print(f"Outputting videos to: {outputs_folder}")

    free_mem_gb = get_cuda_free_memory_gb(gpu if torch.cuda.is_available() else None)
    # The high_vram threshold is a heuristic and can be tuned
    high_vram = free_mem_gb > 30  # e.g., 30 GB+ of free VRAM counts as "high VRAM"
    print(f'Free VRAM {free_mem_gb:.2f} GB. High-VRAM Mode: {high_vram}')

    if args.vae_batch_size == -1:
        if free_mem_gb >= 18:
            args.vae_batch_size = 64
        elif free_mem_gb >= 10:
            args.vae_batch_size = 32
        else:
            args.vae_batch_size = 16
        print(f"Auto-set VAE batch size to: {args.vae_batch_size}")
print("Loading models...") | |
loading_device_str = str(cpu) # Load to CPU first | |
transformer = load_packed_model( | |
device=loading_device_str, | |
dit_path=args.dit, | |
attn_mode=args.attn_mode, | |
loading_device=loading_device_str | |
) | |
print("DiT loaded.") | |
if args.lora_weight is not None and len(args.lora_weight) > 0: | |
print("Merging LoRA weights...") | |
if len(args.lora_multiplier) == 1 and len(args.lora_weight) > 1: | |
args.lora_multiplier = args.lora_multiplier * len(args.lora_weight) | |
elif len(args.lora_multiplier) != len(args.lora_weight): | |
parser.error(f"Number of LoRA weights ({len(args.lora_weight)}) and multipliers ({len(args.lora_multiplier)}) must match, or provide a single multiplier.") | |
try: | |
# Mimic fpack_generate_video.py's LoRA args structure if needed by merge_lora_weights | |
if not hasattr(args, 'lycoris'): args.lycoris = False | |
if not hasattr(args, 'save_merged_model'): args.save_merged_model = None | |
current_device_for_lora = torch.device(loading_device_str) | |
merge_lora_weights(lora_framepack, transformer, args, current_device_for_lora) | |
print("LoRA weights merged successfully.") | |
except Exception as e_lora: | |
print(f"Error merging LoRA weights: {e_lora}") | |
traceback.print_exc() | |
    vae = load_vae(
        vae_path=args.vae,
        vae_chunk_size=args.vae_chunk_size,
        vae_spatial_tile_sample_min_size=args.vae_spatial_tile_sample_min_size,
        device=loading_device_str
    )
    print("VAE loaded.")

    # fpack_generate_video.py uses args.fp8_llm when loading text_encoder1;
    # like f1_video_cli_local.py, we pass `args` through directly.
    tokenizer, text_encoder = load_text_encoder1(args, device=loading_device_str)
    print("Text Encoder 1 and Tokenizer 1 loaded.")
    tokenizer_2, text_encoder_2 = load_text_encoder2(args)
    print("Text Encoder 2 and Tokenizer 2 loaded.")
    feature_extractor, image_encoder = load_image_encoders(args)
    print("Image Encoder and Feature Extractor loaded.")

    all_models_list = [transformer, vae, text_encoder, text_encoder_2, image_encoder]
    for model_obj in all_models_list:
        if model_obj is not None:
            model_obj.eval().requires_grad_(False)

    # Set dtypes (original FramePack typically uses bfloat16 for the DiT and float16 for the other models)
    if transformer: transformer.to(dtype=torch.bfloat16)
    if vae: vae.to(dtype=torch.float16)
    if image_encoder: image_encoder.to(dtype=torch.float16)
    if text_encoder: text_encoder.to(dtype=torch.float16)  # or bfloat16, depending on how --fp8_llm is handled
    if text_encoder_2: text_encoder_2.to(dtype=torch.float16)

    if transformer:
        transformer.high_quality_fp32_output_for_inference = True  # common setting
        print('Transformer: high_quality_fp32_output_for_inference = True')

    if vae and not high_vram:
        vae.enable_slicing()
        vae.enable_tiling()

    target_gpu_device_str = str(gpu if torch.cuda.is_available() else cpu)
    if not high_vram and torch.cuda.is_available():
        print("Low VRAM mode: Setting up dynamic swapping for DiT and Text Encoder 1.")
        if transformer: DynamicSwapInstaller.install_model(transformer, device=target_gpu_device_str)
        if text_encoder: DynamicSwapInstaller.install_model(text_encoder, device=target_gpu_device_str)
        # The other models (VAE, TE2, image encoder) are loaded/offloaded on demand via load_model_as_complete / unload_complete_models
        if vae: vae.to(cpu)
        if text_encoder_2: text_encoder_2.to(cpu)
        if image_encoder: image_encoder.to(cpu)
    elif torch.cuda.is_available():
        print(f"High VRAM mode: Moving all models to {target_gpu_device_str}.")
        for model_obj in all_models_list:
            if model_obj is not None: model_obj.to(target_gpu_device_str)
    else:
        print("Running on CPU. Models remain on CPU.")
    print("All models loaded and configured.")

    # Adjust gs if cfg > 1.0 (standard CFG mode)
    actual_gs_cli = args.gs
    if args.cfg > 1.0:
        actual_gs_cli = 1.0  # with standard CFG, distilled guidance is turned off
        print(f"CFG > 1.0 detected ({args.cfg}), this implies standard CFG. Overriding GS to 1.0 from {args.gs}.")
    do_generation_work(
        input_video_path=args.input_video,
        prompt=args.prompt,
        n_prompt=args.n_prompt,
        seed=args.seed,
        end_frame_path=args.end_frame,
        end_frame_weight=args.end_frame_weight,
        resolution_max_dim=args.resolution_max_dim,
        additional_second_length=args.total_second_length,
        latent_window_size=args.latent_window_size,
        steps=args.steps,
        cfg=args.cfg,
        gs=actual_gs_cli,
        rs=args.rs,
        gpu_memory_preservation=args.gpu_memory_preservation,
        use_teacache=args.use_teacache,
        no_resize=args.no_resize,
        mp4_crf=args.mp4_crf,
        num_clean_frames=args.num_clean_frames,
        vae_batch_size=args.vae_batch_size,
        extension_only=args.extension_only
    )
    print("Video generation process completed.")