import argparse
from datetime import datetime
from pathlib import Path
import random
import sys
import os
import time
from typing import Optional, Union

import numpy as np
import torch
import torchvision
import accelerate
from diffusers.utils.torch_utils import randn_tensor
from transformers.models.llama import LlamaModel
from tqdm import tqdm
import av
from einops import rearrange
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from PIL import Image

from hunyuan_model import vae
from hunyuan_model.text_encoder import TextEncoder
from hunyuan_model.text_encoder import PROMPT_TEMPLATE
from hunyuan_model.vae import load_vae
from hunyuan_model.models import load_transformer, get_rotary_pos_embed
from hunyuan_model.fp8_optimization import convert_fp8_linear
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
from networks import lora

try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    pass

from utils.model_utils import str_to_dtype
from utils.safetensors_utils import mem_eff_save_file
from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def clean_memory_on_device(device):
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "cpu":
        pass
    elif device.type == "mps":  # not tested
        torch.mps.empty_cache()


def synchronize_device(device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
    """save videos by video tensor
       copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61

    Args:
        videos (torch.Tensor): video tensor predicted by the model
        path (str): path to save video
        rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
        n_rows (int, optional): Defaults to 1.
        fps (int, optional): video save fps. Defaults to 24.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)

    # # save video with av
    # container = av.open(path, "w")
    # stream = container.add_stream("libx264", rate=fps)
    # for x in outputs:
    #     frame = av.VideoFrame.from_ndarray(x, format="rgb24")
    #     packet = stream.encode(frame)
    #     container.mux(packet)
    # packet = stream.encode(None)
    # container.mux(packet)
    # container.close()

    height, width, _ = outputs[0].shape

    # create output container
    container = av.open(path, mode="w")

    # create video stream
    codec = "libx264"
    pixel_format = "yuv420p"
    stream = container.add_stream(codec, rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = pixel_format
    stream.bit_rate = 4000000  # 4Mbit/s

    for frame_array in outputs:
        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
        packets = stream.encode(frame)
        for packet in packets:
            container.mux(packet)

    # flush the encoder
    for packet in stream.encode():
        container.mux(packet)

    container.close()
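
# Note (added for clarity): save_videos_grid expects a (b, c, t, h, w) tensor with values already
# in [0, 1], or in [-1, 1] when rescale=True. A minimal, hypothetical usage sketch:
#   frames = torch.rand(1, 3, 16, 256, 256)  # dummy video, values in [0, 1]
#   save_videos_grid(frames, "output/sample.mp4", fps=24)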


def save_images_grid(
    videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    if create_subdir:
        output_dir = os.path.join(parent_dir, image_name)
    else:
        output_dir = parent_dir

    os.makedirs(output_dir, exist_ok=True)
    for i, x in enumerate(outputs):
        image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
        image = Image.fromarray(x)
        image.save(image_path)


# region Encoding prompt


def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_videos_per_prompt (`int`):
            number of videos that should be generated per prompt
        text_encoder (TextEncoder):
            text encoder to be used for encoding the prompt
    """
    # LoRA and Textual Inversion are not supported in this script
    # negative prompt and prompt embedding are not supported in this script
    # clip_skip is not supported in this script because it is not used in the original script
    data_type = "video"  # video only, image is not supported

    text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

    with torch.no_grad():
        prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
    prompt_embeds = prompt_outputs.hidden_state

    attention_mask = prompt_outputs.attention_mask
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
        bs_embed, seq_len = attention_mask.shape
        attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
        attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)

    prompt_embeds_dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    if prompt_embeds.ndim == 2:
        bs_embed, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
    else:
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds, attention_mask
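
# Usage sketch (illustrative only; the variable names are hypothetical): once a TextEncoder has
# been built as in encode_input_prompt below, encode_prompt can be called directly:
#   embeds, mask = encode_prompt("a cat walking on grass", device, 1, text_encoder)
# For the LLM encoder the embeddings keep a sequence dimension; for the second (CLIP-L) encoder
# they come back pooled, which is what the ndim == 2 branch above handles.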


def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
    # constants
    prompt_template_video = "dit-llm-encode-video"
    prompt_template = "dit-llm-encode"
    text_encoder_dtype = torch.float16
    text_encoder_type = "llm"
    text_len = 256
    hidden_state_skip_layer = 2
    apply_final_norm = False
    reproduce = False

    text_encoder_2_type = "clipL"
    text_len_2 = 77

    num_videos = 1

    # if args.prompt_template_video is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
    # elif args.prompt_template is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
    # else:
    #     crop_start = 0
    crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
    max_length = text_len + crop_start

    # prompt_template
    prompt_template = PROMPT_TEMPLATE[prompt_template]

    # prompt_template_video
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]  # if args.prompt_template_video is not None else None

    # load text encoders
    logger.info(f"loading text encoder: {args.text_encoder1}")
    text_encoder = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder1,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder.eval()

    if fp8_llm:
        org_dtype = text_encoder.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder.model, org_dtype)

    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = TextEncoder(
        text_encoder_type=text_encoder_2_type,
        max_length=text_len_2,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder2,
        tokenizer_type=text_encoder_2_type,
        reproduce=reproduce,
    )
    text_encoder_2.eval()

    # encode prompt
    logger.info("Encoding prompt with text encoder 1")
    text_encoder.to(device=device)
    if fp8_llm:
        with accelerator.autocast():
            prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
    else:
        prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)

    text_encoder = None
    clean_memory_on_device(device)

    logger.info("Encoding prompt with text encoder 2")
    text_encoder_2.to(device=device)
    prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)

    prompt_embeds = prompt_embeds.to("cpu")
    prompt_mask = prompt_mask.to("cpu")
    prompt_embeds_2 = prompt_embeds_2.to("cpu")
    prompt_mask_2 = prompt_mask_2.to("cpu")

    text_encoder_2 = None
    clean_memory_on_device(device)

    return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2


# endregion
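
# The four tensors returned above are deliberately moved to CPU so that both text encoders can be
# released before the much larger DiT is loaded; main() moves them back to the target device (and
# casts the embeddings to the DiT dtype) right before the denoising loop.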


def prepare_vae(args, device):
    vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()
    # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

    # set chunk_size to CausalConv3d recursively
    chunk_size = args.vae_chunk_size
    if chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(chunk_size)
        logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")

    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    # elif args.vae_tiling:
    else:
        vae.enable_spatial_tiling(True)

    return vae, vae_dtype


def encode_to_latents(args, video, device):
    vae, vae_dtype = prepare_vae(args, device)

    video = video.to(device=device, dtype=vae_dtype)
    video = video * 2 - 1  # 0, 1 -> -1, 1
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()

    if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
        latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
    else:
        latents = latents * vae.config.scaling_factor

    return latents


def decode_latents(args, latents, device):
    vae, vae_dtype = prepare_vae(args, device)

    expand_temporal_dim = False
    if len(latents.shape) == 4:
        latents = latents.unsqueeze(2)
        expand_temporal_dim = True
    elif len(latents.shape) == 5:
        pass
    else:
        raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")

    if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
        latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    else:
        latents = latents / vae.config.scaling_factor

    latents = latents.to(device=device, dtype=vae_dtype)
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]

    if expand_temporal_dim:
        image = image.squeeze(2)

    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().float()

    return image
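
# Scaling round trip (sketch): encode_to_latents and decode_latents invert each other up to the
# VAE reconstruction error. With shift_factor s and scaling_factor k from the VAE config:
#   encode: z = (E(x) - s) * k        decode: x ~= D(z / k + s)
# and without a shift factor simply z = E(x) * k, x ~= D(z / k).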


def parse_args():
    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
    parser.add_argument(
        "--dit_in_channels",
        type=int,
        default=None,
        help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
    )
    parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
    parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )
    parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")

    # inference
    parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
    parser.add_argument("--video_length", type=int, default=129, help="video length")
    parser.add_argument("--fps", type=int, default=24, help="video fps")
    parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier-free guidance. Default is 1.0 (no guidance)",
    )
    parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier-free guidance scale.")
    parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
    parser.add_argument(
        "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
    )
    parser.add_argument(
        "--split_uncond",
        action="store_true",
        help="split unconditional call for classifier-free guidance, slower but uses less memory",
    )
    parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")

    # Flow Matching
    parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")

    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
    )
    parser.add_argument(
        "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
    parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
    parser.add_argument(
        "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
    parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 40xx and later)")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument(
        "--compile_args",
        nargs=4,
        metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
        default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
        help="Torch.compile settings",
    )

    args = parser.parse_args()

    assert (args.latent_path is None or len(args.latent_path) == 0) or (
        args.output_type == "images" or args.output_type == "video"
    ), "latent_path is only supported for images or video output"

    # update dit_weight based on model_base if not exists

    if args.fp8_fast and not args.fp8:
        raise ValueError("--fp8_fast requires --fp8")

    return args
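
# Example invocation (sketch; the script filename and all paths are placeholders):
#   python hv_generate_video.py \
#     --dit /path/to/dit.safetensors --vae /path/to/vae.safetensors \
#     --text_encoder1 /path/to/text_encoder1 --text_encoder2 /path/to/text_encoder2 \
#     --prompt "a cat walking on grass" --video_size 544 960 --video_length 129 \
#     --infer_steps 30 --attn_mode sdpa --output_type video --save_path ./outputs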


def check_inputs(args):
    height = args.video_size[0]
    width = args.video_size[1]
    video_length = args.video_length

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    return height, width, video_length


def main():
    args = parse_args()

    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    dit_dtype = torch.bfloat16
    dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
    logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")

    original_base_names = None
    if args.latent_path is not None and len(args.latent_path) > 0:
        original_base_names = []
        latents_list = []
        seeds = []
        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                if "seeds" in metadata:
                    seed = int(metadata["seeds"])

            seeds.append(seed)
            latents_list.append(latents)

            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")

        latents = torch.stack(latents_list, dim=0)
    else:
        # prepare accelerator
        mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
        accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)

        # load prompt
        prompt = args.prompt  # TODO load prompts from file
        assert prompt is not None, "prompt is required"

        # check inputs: height, width, video_length etc. may be changed for each generation in the future
        height, width, video_length = check_inputs(args)

        # encode prompt with LLM and Text Encoder
        logger.info(f"Encoding prompt: {prompt}")

        do_classifier_free_guidance = args.guidance_scale != 1.0
        if do_classifier_free_guidance:
            negative_prompt = args.negative_prompt
            if negative_prompt is None:
                logger.info("Negative prompt is not provided, using empty prompt")
                negative_prompt = ""
            logger.info(f"Encoding negative prompt: {negative_prompt}")
            prompt = [negative_prompt, prompt]
        else:
            if args.negative_prompt is not None:
                logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")

        prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
            prompt, args, device, args.fp8_llm, accelerator
        )

        # encode latents for video2video inference
        video_latents = None
        if args.video_path is not None:
            # v2v inference
            logger.info(f"Video2Video inference: {args.video_path}")

            video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height))  # list of frames
            if len(video) < video_length:
                raise ValueError(f"Video length is less than {video_length}")
            video = np.stack(video, axis=0)  # F, H, W, C
            video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float()  # 1, C, F, H, W
            video = video / 255.0

            logger.info("Encoding video to latents")
            video_latents = encode_to_latents(args, video, device)
            video_latents = video_latents.to(device=device, dtype=dit_dtype)

            clean_memory_on_device(device)

        # encode latents for image2video inference
        image_latents = None
        if args.image_path is not None:
            # i2v inference
            logger.info(f"Image2Video inference: {args.image_path}")

            image = Image.open(args.image_path)
            image = resize_image_to_bucket(image, (width, height))  # returns a numpy array
            image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float()  # 1, C, 1, H, W
            image = image / 255.0

            logger.info("Encoding image to latents")
            image_latents = encode_to_latents(args, image, device)  # 1, C, 1, H, W
            image_latents = image_latents.to(device=device, dtype=dit_dtype)

            clean_memory_on_device(device)
        # load DiT model
        blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
        loading_device = "cpu"  # if blocks_to_swap > 0 else device

        logger.info(f"Loading DiT model from {args.dit}")
        if args.attn_mode == "sdpa":
            args.attn_mode = "torch"

        # if image_latents is given, the model should be an I2V model, so the in_channels should be 32
        dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)

        # if we use LoRA, weights should be bf16 instead of fp8, because merging should be done in bf16
        # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
        # on the fly merging will be a solution for this issue for .safetensors files (not implemented yet)
        transformer = load_transformer(
            args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
        )
        transformer.eval()

        # load LoRA weights
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            for i, lora_weight in enumerate(args.lora_weight):
                if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
                    lora_multiplier = args.lora_multiplier[i]
                else:
                    lora_multiplier = 1.0

                logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
                weights_sd = load_file(lora_weight)

                # Filter to exclude keys that are part of single_blocks
                if args.exclude_single_blocks:
                    filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
                    weights_sd = filtered_weights

                if args.lycoris:
                    lycoris_net, _ = create_network_from_weights(
                        multiplier=lora_multiplier,
                        file=None,
                        weights_sd=weights_sd,
                        unet=transformer,
                        text_encoder=None,
                        vae=None,
                        for_inference=True,
                    )
                else:
                    network = lora.create_arch_network_from_weights(
                        lora_multiplier, weights_sd, unet=transformer, for_inference=True
                    )
                logger.info("Merging LoRA weights to DiT model")

                # try:
                #     network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
                #     info = network.load_state_dict(weights_sd, strict=True)
                #     logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
                #     network.eval()
                #     network.to(device)
                # except Exception as e:
                if args.lycoris:
                    lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
                else:
                    network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)

                synchronize_device(device)

                logger.info("LoRA weights loaded")

            # save model here before casting to dit_weight_dtype
            if args.save_merged_model:
                logger.info(f"Saving merged model to {args.save_merged_model}")
                mem_eff_save_file(transformer.state_dict(), args.save_merged_model)  # save_file needs a lot of memory
                logger.info("Merged model saved")
                return
logger.info(f"Casting model to {dit_weight_dtype}") | |
transformer.to(dtype=dit_weight_dtype) | |
if args.fp8_fast: | |
logger.info("Enabling FP8 acceleration") | |
params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} | |
for name, param in transformer.named_parameters(): | |
dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype | |
param.to(dtype=dtype_to_use) | |
convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep) | |
if args.compile: | |
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args | |
logger.info( | |
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" | |
) | |
torch._dynamo.config.cache_size_limit = 32 | |
for i, block in enumerate(transformer.single_blocks): | |
compiled_block = torch.compile( | |
block, | |
backend=compile_backend, | |
mode=compile_mode, | |
dynamic=compile_dynamic.lower() in "true", | |
fullgraph=compile_fullgraph.lower() in "true", | |
) | |
transformer.single_blocks[i] = compiled_block | |
for i, block in enumerate(transformer.double_blocks): | |
compiled_block = torch.compile( | |
block, | |
backend=compile_backend, | |
mode=compile_mode, | |
dynamic=compile_dynamic.lower() in "true", | |
fullgraph=compile_fullgraph.lower() in "true", | |
) | |
transformer.double_blocks[i] = compiled_block | |
if blocks_to_swap > 0: | |
logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}") | |
transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False) | |
transformer.move_to_device_except_swap_blocks(device) | |
transformer.prepare_block_swap_before_forward() | |
else: | |
logger.info(f"Moving model to {device}") | |
transformer.to(device=device) | |
if args.img_in_txt_in_offloading: | |
logger.info("Enable offloading img_in and txt_in to CPU") | |
transformer.enable_img_in_txt_in_offloading() | |
        # load scheduler
        logger.info("Loading scheduler")
        scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")

        # Prepare timesteps
        num_inference_steps = args.infer_steps
        scheduler.set_timesteps(num_inference_steps, device=device)  # n_tokens is not used in FlowMatchDiscreteScheduler
        timesteps = scheduler.timesteps

        # Prepare generator
        num_videos_per_prompt = 1  # args.num_videos # currently only support 1 video per prompt, this is a batch size
        seed = args.seed
        if seed is None:
            seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
        elif isinstance(seed, int):
            seeds = [seed + i for i in range(num_videos_per_prompt)]
        else:
            raise ValueError(f"Seed must be an integer or None, got {seed}.")
        generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]

        # Prepare noisy latents
        num_channels_latents = 16  # transformer.config.in_channels
        vae_scale_factor = 2 ** (4 - 1)  # len(self.vae.config.block_out_channels) == 4

        vae_ver = vae.VAE_VER
        if "884" in vae_ver:
            latent_video_length = (video_length - 1) // 4 + 1
        elif "888" in vae_ver:
            latent_video_length = (video_length - 1) // 8 + 1
        else:
            latent_video_length = video_length

        # shape = (
        #     num_videos_per_prompt,
        #     num_channels_latents,
        #     latent_video_length,
        #     height // vae_scale_factor,
        #     width // vae_scale_factor,
        # )
        # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)

        # make the first N frames the same if the given seed is the same
        shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
        latents = []
        for i in range(latent_video_length):
            latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
        latents = torch.cat(latents, dim=2)

        # pad image_latents to match the length of video_latents
        if image_latents is not None:
            zero_latents = torch.zeros_like(latents)
            zero_latents[:, :, :1, :, :] = image_latents
            image_latents = zero_latents

        if args.video_path is not None:
            # v2v inference
            noise = latents
            assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"

            num_inference_steps = int(num_inference_steps * args.strength)
            timestep_start = scheduler.timesteps[-num_inference_steps]  # larger strength -> fewer steps and a noisier start point
            t = timestep_start / 1000.0
            latents = noise * t + video_latents * (1 - t)
            timesteps = timesteps[-num_inference_steps:]

            logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")

        # FlowMatchDiscreteScheduler does not have init_noise_sigma

        # Denoising loop
        embedded_guidance_scale = args.embedded_cfg_scale
        if embedded_guidance_scale is not None:
            guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
            guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
            if do_classifier_free_guidance:
                guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
        else:
            guidance_expand = None
        freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
        # n_tokens = freqs_cos.shape[0]

        # move and cast all inputs to the correct device and dtype
        prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
        prompt_mask = prompt_mask.to(device=device)
        prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
        prompt_mask_2 = prompt_mask_2.to(device=device)

        freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
        freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)

        num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order  # this should be 0 in v2v inference

        # assert split_uncond and split_attn
        if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
            logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
            args.split_uncond = True

        # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = scheduler.scale_model_input(latents, t)

                # predict the noise residual
                with torch.no_grad(), accelerator.autocast():
                    latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
                    if image_latents is not None:
                        latents_image_input = (
                            image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
                        )
                        latents_input = torch.cat([latents_input, latents_image_input], dim=1)  # 1 or 2, C*2, F, H, W

                    batch_size = 1 if args.split_uncond else latents_input.shape[0]

                    noise_pred_list = []
                    for j in range(0, latents_input.shape[0], batch_size):
                        noise_pred = transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                            latents_input[j : j + batch_size],  # [1, 16, 33, 24, 42]
                            t.repeat(batch_size).to(device=device, dtype=dit_dtype),  # [1]
                            text_states=prompt_embeds[j : j + batch_size],  # [1, 256, 4096]
                            text_mask=prompt_mask[j : j + batch_size],  # [1, 256]
                            text_states_2=prompt_embeds_2[j : j + batch_size],  # [1, 768]
                            freqs_cos=freqs_cos,  # [seqlen, head_dim]
                            freqs_sin=freqs_sin,  # [seqlen, head_dim]
                            guidance=guidance_expand[j : j + batch_size],  # [1]
                            return_dict=True,
                        )["x"]
                        noise_pred_list.append(noise_pred)
                    noise_pred = torch.cat(noise_pred_list, dim=0)

                # perform classifier free guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # # SkyReels' rescale noise config is omitted for now
                    # if guidance_rescale > 0.0:
                    #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    #     noise_pred = rescale_noise_cfg(
                    #         noise_pred,
                    #         noise_pred_cond,
                    #         guidance_rescale=self.guidance_rescale,
                    #     )

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                    if progress_bar is not None:
                        progress_bar.update()

        # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
        # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

        latents = latents.detach().cpu()
        transformer = None
        clean_memory_on_device(device)
    # Save samples
    output_type = args.output_type
    save_path = args.save_path  # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    if output_type == "latent" or output_type == "both":
        # save latent
        for i, latent in enumerate(latents):
            latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"

            if args.no_metadata:
                metadata = None
            else:
                metadata = {
                    "seeds": f"{seeds[i]}",
                    "prompt": f"{args.prompt}",
                    "height": f"{height}",
                    "width": f"{width}",
                    "video_length": f"{video_length}",
                    "infer_steps": f"{num_inference_steps}",
                    "guidance_scale": f"{args.guidance_scale}",
                    "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
                }
                if args.negative_prompt is not None:
                    metadata["negative_prompt"] = f"{args.negative_prompt}"

            sd = {"latent": latent}
            save_file(sd, latent_path, metadata=metadata)

            logger.info(f"Latent saved to: {latent_path}")

    if output_type == "video" or output_type == "both":
        # save video
        videos = decode_latents(args, latents, device)
        for i, sample in enumerate(videos):
            original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
            sample = sample.unsqueeze(0)
            video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
            save_videos_grid(sample, video_path, fps=args.fps)
            logger.info(f"Sample saved to: {video_path}")

    elif output_type == "images":
        # save images
        videos = decode_latents(args, latents, device)
        for i, sample in enumerate(videos):
            original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
            sample = sample.unsqueeze(0)
            image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
            save_images_grid(sample, save_path, image_name)
            logger.info(f"Sample images saved to: {save_path}/{image_name}")

    logger.info("Done!")


if __name__ == "__main__":
    main()