# original code: https://github.com/lllyasviel/FramePack
# original license: Apache-2.0

import glob
import math
import numbers
import os
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import einops
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from modules.custom_offloading_utils import ModelOffloader
from utils.safetensors_utils import load_split_weights
from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
from accelerate import init_empty_weights

try:
    from xformers.ops import memory_efficient_attention as xformers_attn_func

    print("Xformers is installed!")
except ImportError:
    print("Xformers is not installed!")
    xformers_attn_func = None

try:
    from flash_attn import flash_attn_varlen_func, flash_attn_func

    print("Flash Attn is installed!")
except ImportError:
    print("Flash Attn is not installed!")
    flash_attn_varlen_func = None
    flash_attn_func = None

try:
    from sageattention import sageattn_varlen, sageattn

    print("Sage Attn is installed!")
except ImportError:
    print("Sage Attn is not installed!")
    sageattn_varlen = None
    sageattn = None

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# region diffusers
# copied from diffusers with some modifications to minimize dependencies
# original code: https://github.com/huggingface/diffusers/
# original license: Apache-2.0

ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int): the dimension of the output.
        flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float): Controls the delta between frequencies between dimensions
        scale (float): Scaling factor applied to the embeddings.
        max_period (int): Controls the maximum frequency of the embeddings
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
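
    Example (illustrative; output values depend on `flip_sin_to_cos` and `downscale_freq_shift`):
        >>> emb = get_timestep_embedding(torch.tensor([1.0, 999.0]), embedding_dim=256)
        >>> emb.shape
        torch.Size([2, 256])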
""" assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb class TimestepEmbedding(nn.Module): def __init__( self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) if post_act_fn is None: self.post_act = None else: self.post_act = get_activation(post_act_fn) def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) if self.post_act is not None: sample = self.post_act(sample) return sample class Timesteps(nn.Module): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift self.scale = scale def forward(self, timesteps): t_emb = get_timestep_embedding( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, scale=self.scale, ) return t_emb class FP32SiLU(nn.Module): r""" SiLU activation function with input upcasted to torch.float32. """ def __init__(self): super().__init__() def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.silu(inputs.float(), inplace=False).to(inputs.dtype) class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
""" def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): super().__init__() self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: # if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): # # fp16 gelu not supported on mps before torch 2.0 # return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) return F.gelu(gate, approximate=self.approximate) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states = self.gelu(hidden_states) return hidden_states class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): super().__init__() if out_features is None: out_features = hidden_size self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() elif act_fn == "silu_fp32": self.act_1 = FP32SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states class LayerNormFramePack(nn.LayerNorm): # casting to dtype of input tensor is added def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x) class FP32LayerNormFramePack(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: origin_dtype = x.dtype return torch.nn.functional.layer_norm( x.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps, ).to(origin_dtype) class RMSNormFramePack(nn.Module): r""" RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al. Args: dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True. eps (`float`): Small value to use when calculating the reciprocal of the square-root. elementwise_affine (`bool`, defaults to `True`): Boolean flag to denote if affine transformation should be applied. bias (`bool`, defaults to False): If also training the `bias` param. 
""" def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): super().__init__() self.eps = eps self.elementwise_affine = elementwise_affine if isinstance(dim, numbers.Integral): dim = (dim,) self.dim = torch.Size(dim) self.weight = None self.bias = None if elementwise_affine: self.weight = nn.Parameter(torch.ones(dim)) if bias: self.bias = nn.Parameter(torch.zeros(dim)) def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) if self.weight is None: return hidden_states.to(input_dtype) return hidden_states.to(input_dtype) * self.weight.to(input_dtype) class AdaLayerNormContinuousFramePack(nn.Module): r""" Adaptive normalization layer with a norm layer (layer_norm or rms_norm). Args: embedding_dim (`int`): Embedding dimension to use during projection. conditioning_embedding_dim (`int`): Dimension of the input condition. elementwise_affine (`bool`, defaults to `True`): Boolean flag to denote if affine transformation should be applied. eps (`float`, defaults to 1e-5): Epsilon factor. bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. norm_type (`str`, defaults to `"layer_norm"`): Normalization layer to use. Values supported: "layer_norm", "rms_norm". """ def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters # because the output is immediately scaled and shifted by the projected conditioning embeddings. # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. # However, this is how it was implemented in the original code, and it's rather likely you should # set `elementwise_affine` to False. elementwise_affine=True, eps=1e-5, bias=True, norm_type="layer_norm", ): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias) elif norm_type == "rms_norm": self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x, conditioning_embedding): emb = self.linear(self.silu(conditioning_embedding)) scale, shift = emb.chunk(2, dim=1) x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x class LinearActivation(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): super().__init__() self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.activation = get_activation(activation) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) return self.activation(hidden_states) class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: dim (`int`): The number of channels in the input. dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
""" def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, inner_dim=None, bias: bool = True, ): super().__init__() if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim # if activation_fn == "gelu": # act_fn = GELU(dim, inner_dim, bias=bias) if activation_fn == "gelu-approximate": act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) # elif activation_fn == "geglu": # act_fn = GEGLU(dim, inner_dim, bias=bias) # elif activation_fn == "geglu-approximate": # act_fn = ApproximateGELU(dim, inner_dim, bias=bias) # elif activation_fn == "swiglu": # act_fn = SwiGLU(dim, inner_dim, bias=bias) elif activation_fn == "linear-silu": act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") else: raise ValueError(f"Unknown activation function: {activation_fn}") self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." # deprecate("scale", "1.0.0", deprecation_message) raise ValueError("scale is not supported in this version. Please remove it.") for module in self.net: hidden_states = module(hidden_states) return hidden_states # @maybe_allow_in_graph class Attention(nn.Module): r""" Minimal copy of Attention class from diffusers. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, bias: bool = False, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, eps: float = 1e-5, processor: Optional[any] = None, out_dim: int = None, context_pre_only=None, pre_only=False, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim # if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim self.use_bias = bias self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only self.scale = dim_head**-0.5 self.heads = out_dim // dim_head if out_dim is not None else heads self.added_kv_proj_dim = added_kv_proj_dim if qk_norm is None: self.norm_q = None self.norm_k = None elif qk_norm == "rms_norm": self.norm_q = RMSNormFramePack(dim_head, eps=eps) self.norm_k = RMSNormFramePack(dim_head, eps=eps) else: raise ValueError( f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." 
) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.added_proj_bias = True # added_proj_bias if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) else: self.add_q_proj = None self.add_k_proj = None self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True)) # self.to_out.append(nn.Dropout(dropout)) self.to_out.append(nn.Identity()) # dropout=0.0 else: self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True) else: self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "rms_norm": self.norm_added_q = RMSNormFramePack(dim_head, eps=eps) self.norm_added_k = RMSNormFramePack(dim_head, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`") else: self.norm_added_q = None self.norm_added_k = None # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention if processor is None: processor = AttnProcessor2_0() self.set_processor(processor) def set_processor(self, processor: any) -> None: self.processor = processor def get_processor(self) -> any: return self.processor def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) def prepare_attention_mask( self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: r""" Prepare the attention mask for the attention computation. Args: attention_mask (`torch.Tensor`): The attention mask to prepare. target_length (`int`): The target length of the attention mask. This is the length of the attention mask after padding. batch_size (`int`): The batch size, which is used to repeat the attention mask. out_dim (`int`, *optional*, defaults to `3`): The output dimension of the attention mask. Can be either `3` or `4`. Returns: `torch.Tensor`: The prepared attention mask. """ head_size = self.heads if attention_mask is None: return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. 
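                # Build the zero padding explicitly and concatenate along the key dimension
                # (same effect as the F.pad call in the non-MPS branch below).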
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size) return attention_mask class AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, *args, **kwargs, ) -> torch.Tensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) query_dtype = query.dtype # store dtype before potentially deleting query if encoder_hidden_states is None: encoder_hidden_states = hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) del query, key, value, attention_mask # free memory hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query_dtype) # use stored dtype # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) return hidden_states # endregion diffusers def pad_for_3d_conv(x, kernel_size): b, c, t, h, w = 
x.shape pt, ph, pw = kernel_size pad_t = (pt - (t % pt)) % pt pad_h = (ph - (h % ph)) % ph pad_w = (pw - (w % pw)) % pw return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") def center_down_sample_3d(x, kernel_size): # pt, ph, pw = kernel_size # cp = (pt * ph * pw) // 2 # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw) # xc = xp[cp] # return xc return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) def get_cu_seqlens(text_mask, img_len): batch_size = text_mask.shape[0] text_len = text_mask.sum(dim=1) max_len = text_mask.shape[1] + img_len cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device) # ensure device match for i in range(batch_size): s = text_len[i] + img_len s1 = i * max_len + s s2 = (i + 1) * max_len cu_seqlens[2 * i + 1] = s1 cu_seqlens[2 * i + 2] = s2 return cu_seqlens def apply_rotary_emb_transposed(x, freqs_cis): cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) del freqs_cis x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1) x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) del x_real, x_imag return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False): if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: if attn_mode == "sageattn" or attn_mode is None and sageattn is not None: x = sageattn(q, k, v, tensor_layout="NHD") return x if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None: x = flash_attn_func(q, k, v) return x if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None: x = xformers_attn_func(q, k, v) return x x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose( 1, 2 ) return x if split_attn: if attn_mode == "sageattn" or attn_mode is None and sageattn is not None: x = torch.empty_like(q) for i in range(q.size(0)): x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD") return x if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None: x = torch.empty_like(q) for i in range(q.size(0)): x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1]) return x if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None: x = torch.empty_like(q) for i in range(q.size(0)): x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1]) return x q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) x = torch.empty_like(q) for i in range(q.size(0)): x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1]) x = x.transpose(1, 2) return x batch_size = q.shape[0] q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None: x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) del q, k, v # free memory elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None: x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) del q, k, v # free memory else: raise NotImplementedError("No Attn Installed or batch_size > 1 is not supported in 
this configuration. Try `--split_attn`.") x = x.view(batch_size, max_seqlen_q, *x.shape[2:]) return x class HunyuanAttnProcessorFlashAttnDouble: def __call__( self, attn: Attention, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, attn_mode: Optional[str] = None, split_attn: Optional[bool] = False, ): cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask # Project image latents query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) del hidden_states # free memory query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) query = apply_rotary_emb_transposed(query, image_rotary_emb) key = apply_rotary_emb_transposed(key, image_rotary_emb) del image_rotary_emb # free memory # Project context (text/encoder) embeddings encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) txt_length = encoder_hidden_states.shape[1] # store length before deleting del encoder_hidden_states # free memory encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) # Concatenate image and context q, k, v query = torch.cat([query, encoder_query], dim=1) key = torch.cat([key, encoder_key], dim=1) value = torch.cat([value, encoder_value], dim=1) del encoder_query, encoder_key, encoder_value # free memory hidden_states_attn = attn_varlen_func( query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn ) del query, key, value # free memory hidden_states_attn = hidden_states_attn.flatten(-2) hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:] del hidden_states_attn # free memory # Apply output projections hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) # Dropout/Identity encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states class HunyuanAttnProcessorFlashAttnSingle: def __call__( self, attn: Attention, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, attn_mode: Optional[str] = None, split_attn: Optional[bool] = False, ): cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask txt_length = encoder_hidden_states.shape[1] # Store text length # Concatenate image and context inputs hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1) del hidden_states, encoder_hidden_states # free memory # Project concatenated inputs query = attn.to_q(hidden_states_cat) key = attn.to_k(hidden_states_cat) value = attn.to_v(hidden_states_cat) del hidden_states_cat # free memory query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1) key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) del image_rotary_emb # free memory hidden_states = 
attn_varlen_func( query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn ) del query, key, value # free memory hidden_states = hidden_states.flatten(-2) hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] return hidden_states, encoder_hidden_states class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") def forward(self, timestep, guidance, pooled_projection): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) guidance_proj = self.time_proj(guidance) guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) time_guidance_emb = timesteps_emb + guidance_emb pooled_projections = self.text_embedder(pooled_projection) conditioning = time_guidance_emb + pooled_projections return conditioning class CombinedTimestepTextProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") def forward(self, timestep, pooled_projection): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) pooled_projections = self.text_embedder(pooled_projection) conditioning = timesteps_emb + pooled_projections return conditioning class HunyuanVideoAdaNorm(nn.Module): def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: super().__init__() out_features = out_features or 2 * in_features self.linear = nn.Linear(in_features, out_features) self.nonlinearity = nn.SiLU() def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: temb = self.linear(self.nonlinearity(temb)) gate_msa, gate_mlp = temb.chunk(2, dim=-1) gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) return gate_msa, gate_mlp class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6) self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, bias=attention_bias, ) self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = 
None, ) -> torch.Tensor: norm_hidden_states = self.norm1(hidden_states) # Self-attention attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=None, attention_mask=attention_mask, ) del norm_hidden_states # free memory gate_msa, gate_mlp = self.norm_out(temb) hidden_states = hidden_states + attn_output * gate_msa del attn_output, gate_msa # free memory ff_output = self.ff(self.norm2(hidden_states)) hidden_states = hidden_states + ff_output * gate_mlp del ff_output, gate_mlp # free memory return hidden_states class HunyuanVideoIndividualTokenRefiner(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, num_layers: int, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() self.refiner_blocks = nn.ModuleList( [ HunyuanVideoIndividualTokenRefinerBlock( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, attention_bias=attention_bias, ) for _ in range(num_layers) ] ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: self_attn_mask = None if attention_mask is not None: batch_size = attention_mask.shape[0] seq_len = attention_mask.shape[1] attention_mask = attention_mask.to(hidden_states.device).bool() self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() self_attn_mask[:, :, :, 0] = True for block in self.refiner_blocks: hidden_states = block(hidden_states, temb, self_attn_mask) return hidden_states class HunyuanVideoTokenRefiner(nn.Module): def __init__( self, in_channels: int, num_attention_heads: int, attention_head_dim: int, num_layers: int, mlp_ratio: float = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels) self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) self.token_refiner = HunyuanVideoIndividualTokenRefiner( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, num_layers=num_layers, mlp_width_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, attention_bias=attention_bias, ) def forward( self, hidden_states: torch.Tensor, timestep: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, ) -> torch.Tensor: if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) else: original_dtype = hidden_states.dtype mask_float = attention_mask.float().unsqueeze(-1) pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) temb = self.time_text_embed(timestep, pooled_projections) del pooled_projections # free memory hidden_states = self.proj_in(hidden_states) hidden_states = self.token_refiner(hidden_states, temb, attention_mask) del temb, attention_mask # free memory return hidden_states class HunyuanVideoRotaryPosEmbed(nn.Module): def __init__(self, rope_dim, theta): super().__init__() self.DT, self.DY, self.DX = rope_dim self.theta = theta @torch.no_grad() def get_frequency(self, dim, pos): T, H, W = pos.shape freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, 
device=pos.device)[: (dim // 2)] / dim)) freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) return freqs.cos(), freqs.sin() @torch.no_grad() def forward_inner(self, frame_indices, height, width, device): GT, GY, GX = torch.meshgrid( frame_indices.to(device=device, dtype=torch.float32), torch.arange(0, height, device=device, dtype=torch.float32), torch.arange(0, width, device=device, dtype=torch.float32), indexing="ij", ) FCT, FST = self.get_frequency(self.DT, GT) del GT # free memory FCY, FSY = self.get_frequency(self.DY, GY) del GY # free memory FCX, FSX = self.get_frequency(self.DX, GX) del GX # free memory result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) del FCT, FCY, FCX, FST, FSY, FSX # free memory # Return result already on the correct device return result # Shape (2 * total_dim / 2, T, H, W) -> (total_dim, T, H, W) @torch.no_grad() def forward(self, frame_indices, height, width, device): frame_indices = frame_indices.unbind(0) results = [self.forward_inner(f, height, width, device) for f in frame_indices] results = torch.stack(results, dim=0) return results class AdaLayerNormZero(nn.Module): def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6) else: raise ValueError(f"unknown norm_type {norm_type}") def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) x = self.norm(x) * (1 + scale_msa) + shift_msa return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class AdaLayerNormZeroSingle(nn.Module): def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6) else: raise ValueError(f"unknown norm_type {norm_type}") def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) x = self.norm(x) * (1 + scale_msa) + shift_msa return x, gate_msa class AdaLayerNormContinuous(nn.Module): def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True, norm_type="layer_norm", ): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias) else: raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) scale, shift = emb.chunk(2, dim=-1) del emb # free memory x = self.norm(x) * (1 + scale) + shift return x class HunyuanVideoSingleTransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0, qk_norm: str = "rms_norm", attn_mode: Optional[str] = None, split_attn: Optional[bool] = 
False, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim mlp_dim = int(hidden_size * mlp_ratio) self.attn_mode = attn_mode self.split_attn = split_attn # Attention layer (pre_only=True means no output projection in Attention module itself) self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=hidden_size, bias=True, processor=HunyuanAttnProcessorFlashAttnSingle(), qk_norm=qk_norm, eps=1e-6, pre_only=True, # Crucial: Attn processor will return raw attention output ) self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") self.proj_mlp = nn.Linear(hidden_size, mlp_dim) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) del encoder_hidden_states # free memory residual = hidden_states # 1. Input normalization norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) norm_hidden_states, norm_encoder_hidden_states = ( norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :], ) # 2. Attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, attn_mode=self.attn_mode, split_attn=self.split_attn, ) attn_output = torch.cat([attn_output, context_attn_output], dim=1) del norm_hidden_states, norm_encoder_hidden_states, context_attn_output # free memory del image_rotary_emb # 3. 
Modulation and residual connection hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) del attn_output, mlp_hidden_states # free memory hidden_states = gate * self.proj_out(hidden_states) hidden_states = hidden_states + residual hidden_states, encoder_hidden_states = ( hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :], ) return hidden_states, encoder_hidden_states class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float, qk_norm: str = "rms_norm", attn_mode: Optional[str] = None, split_attn: Optional[bool] = False, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.attn_mode = attn_mode self.split_attn = split_attn self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, added_kv_proj_dim=hidden_size, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=hidden_size, context_pre_only=False, bias=True, processor=HunyuanAttnProcessorFlashAttnDouble(), qk_norm=qk_norm, eps=1e-6, ) self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) # 2. Joint attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=freqs_cis, attn_mode=self.attn_mode, split_attn=self.split_attn, ) del norm_hidden_states, norm_encoder_hidden_states, freqs_cis # free memory # 3. Modulation and residual connection hidden_states = hidden_states + attn_output * gate_msa del attn_output, gate_msa # free memory encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa del context_attn_output, c_gate_msa # free memory norm_hidden_states = self.norm2(hidden_states) norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp del shift_mlp, scale_mlp # free memory norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp del c_shift_mlp, c_scale_mlp # free memory # 4. 
Feed-forward ff_output = self.ff(norm_hidden_states) del norm_hidden_states # free memory context_ff_output = self.ff_context(norm_encoder_hidden_states) del norm_encoder_hidden_states # free memory hidden_states = hidden_states + gate_mlp * ff_output del ff_output, gate_mlp # free memory encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output del context_ff_output, c_gate_mlp # free memory return hidden_states, encoder_hidden_states class ClipVisionProjection(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.up = nn.Linear(in_channels, out_channels * 3) self.down = nn.Linear(out_channels * 3, out_channels) def forward(self, x): projected_x = self.down(nn.functional.silu(self.up(x))) return projected_x class HunyuanVideoPatchEmbed(nn.Module): def __init__(self, patch_size, in_chans, embed_dim): super().__init__() self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): def __init__(self, inner_dim): super().__init__() self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) @torch.no_grad() def initialize_weight_from_another_conv3d(self, another_layer): weight = another_layer.weight.detach().clone() bias = another_layer.bias.detach().clone() sd = { "proj.weight": weight.clone(), "proj.bias": bias.clone(), "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0, "proj_2x.bias": bias.clone(), "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0, "proj_4x.bias": bias.clone(), } sd = {k: v.clone() for k, v in sd.items()} self.load_state_dict(sd) return class HunyuanVideoTransformer3DModelPacked(nn.Module): # (PreTrainedModelMixin, GenerationMixin, # ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): # @register_to_config def __init__( self, in_channels: int = 16, out_channels: int = 16, num_attention_heads: int = 24, attention_head_dim: int = 128, num_layers: int = 20, num_single_layers: int = 40, num_refiner_layers: int = 2, mlp_ratio: float = 4.0, patch_size: int = 2, patch_size_t: int = 1, qk_norm: str = "rms_norm", guidance_embeds: bool = True, text_embed_dim: int = 4096, pooled_projection_dim: int = 768, rope_theta: float = 256.0, rope_axes_dim: Tuple[int] = (16, 56, 56), has_image_proj=False, image_proj_dim=1152, has_clean_x_embedder=False, attn_mode: Optional[str] = None, split_attn: Optional[bool] = False, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels self.config_patch_size = patch_size self.config_patch_size_t = patch_size_t # 1. Latent and condition embedders self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) self.clean_x_embedder = None self.image_projection = None # 2. RoPE self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) # 3. 
Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( [ HunyuanVideoTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attn_mode=attn_mode, split_attn=split_attn, ) for _ in range(num_layers) ] ) # 4. Single stream transformer blocks self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attn_mode=attn_mode, split_attn=split_attn, ) for _ in range(num_single_layers) ] ) # 5. Output projection self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) self.inner_dim = inner_dim self.use_gradient_checkpointing = False self.enable_teacache = False # if has_image_proj: # self.install_image_projection(image_proj_dim) self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim) # self.config["has_image_proj"] = True # self.config["image_proj_dim"] = in_channels # if has_clean_x_embedder: # self.install_clean_x_embedder() self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) # self.config["has_clean_x_embedder"] = True self.high_quality_fp32_output_for_inference = True # False # change default to True # Block swapping attributes (initialized to None) self.blocks_to_swap = None self.offloader_double = None self.offloader_single = None @property def device(self): return next(self.parameters()).device @property def dtype(self): return next(self.parameters()).dtype def enable_gradient_checkpointing(self): self.use_gradient_checkpointing = True print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.") # Logging def disable_gradient_checkpointing(self): self.use_gradient_checkpointing = False print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.") # Logging def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): self.enable_teacache = enable_teacache self.cnt = 0 self.num_steps = num_steps self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None self.previous_residual = None self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]) if enable_teacache: print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}") else: print("TeaCache disabled.") def gradient_checkpointing_method(self, block, *args): if self.use_gradient_checkpointing: result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) else: result = block(*args) return result def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool): self.blocks_to_swap = num_blocks self.num_double_blocks = len(self.transformer_blocks) self.num_single_blocks = len(self.single_transformer_blocks) double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1 assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, ( f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. " f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." 
) self.offloader_double = ModelOffloader( "double", self.transformer_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device, # debug=True # Optional debugging ) self.offloader_single = ModelOffloader( "single", self.single_transformer_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device, # , debug=True ) print( f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, " + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}." ) def switch_block_swap_for_inference(self): if self.blocks_to_swap and self.blocks_to_swap > 0: self.offloader_double.set_forward_only(True) self.offloader_single.set_forward_only(True) self.prepare_block_swap_before_forward() print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.") def switch_block_swap_for_training(self): if self.blocks_to_swap and self.blocks_to_swap > 0: self.offloader_double.set_forward_only(False) self.offloader_single.set_forward_only(False) self.prepare_block_swap_before_forward() print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.") def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: saved_double_blocks = self.transformer_blocks saved_single_blocks = self.single_transformer_blocks self.transformer_blocks = None self.single_transformer_blocks = None self.to(device) if self.blocks_to_swap: self.transformer_blocks = saved_double_blocks self.single_transformer_blocks = saved_single_blocks def prepare_block_swap_before_forward(self): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks) self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks) def process_input_hidden_states( self, latents, latent_indices=None, clean_latents=None, clean_latent_indices=None, clean_latents_2x=None, clean_latent_2x_indices=None, clean_latents_4x=None, clean_latent_4x_indices=None, ): hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) B, C, T, H, W = hidden_states.shape if latent_indices is None: latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) hidden_states = hidden_states.flatten(2).transpose(1, 2) rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) rope_freqs = rope_freqs.flatten(2).transpose(1, 2) if clean_latents is not None and clean_latent_indices is not None: clean_latents = clean_latents.to(hidden_states) clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) clean_latents = clean_latents.flatten(2).transpose(1, 2) clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) if clean_latents_2x is not None and clean_latent_2x_indices is not None: clean_latents_2x = clean_latents_2x.to(hidden_states) clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) clean_latents_2x = 
clean_latents_2x.flatten(2).transpose(1, 2) clean_latent_2x_rope_freqs = self.rope( frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device ) clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) if clean_latents_4x is not None and clean_latent_4x_indices is not None: clean_latents_4x = clean_latents_4x.to(hidden_states) clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) clean_latent_4x_rope_freqs = self.rope( frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device ) clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) return hidden_states, rope_freqs def forward( self, hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance, latent_indices=None, clean_latents=None, clean_latent_indices=None, clean_latents_2x=None, clean_latent_2x_indices=None, clean_latents_4x=None, clean_latent_4x_indices=None, image_embeddings=None, attention_kwargs=None, return_dict=True, ): if attention_kwargs is None: attention_kwargs = {} batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config_patch_size, self.config_patch_size_t post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p original_context_length = post_patch_num_frames * post_patch_height * post_patch_width hidden_states, rope_freqs = self.process_input_hidden_states( hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices, ) del ( latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices, ) # free memory temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) encoder_hidden_states = self.gradient_checkpointing_method( self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask ) if self.image_projection is not None: assert image_embeddings is not None, "You must use image embeddings!" 
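            # Project the CLIP vision embeddings to the transformer width and prepend them to the
            # text tokens; their attention-mask entries are set to 1 so they are never masked out.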
            extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
            extra_attention_mask = torch.ones(
                (batch_size, extra_encoder_hidden_states.shape[1]),
                dtype=encoder_attention_mask.dtype,
                device=encoder_attention_mask.device,
            )

            # must cat before (not after) encoder_hidden_states, due to attn masking
            encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
            encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)

            del extra_encoder_hidden_states, extra_attention_mask  # free memory

        with torch.no_grad():
            if batch_size == 1:
                # With batch size 1 no masks or var-len attention functions are needed:
                # cropping the text tokens to the valid length is mathematically equivalent to masking.
                text_len = encoder_attention_mask.sum().item()
                encoder_hidden_states = encoder_hidden_states[:, :text_len]
                attention_mask = None, None, None, None
            else:
                img_seq_len = hidden_states.shape[1]
                txt_seq_len = encoder_hidden_states.shape[1]

                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
                cu_seqlens_kv = cu_seqlens_q
                max_seqlen_q = img_seq_len + txt_seq_len
                max_seqlen_kv = max_seqlen_q

                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
                del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv  # free memory
        del encoder_attention_mask  # free memory

        if self.enable_teacache:
            # TeaCache: decide whether to recompute the transformer blocks or reuse the cached residual,
            # based on the accumulated relative L1 change of the modulated input across steps.
            modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

            if self.cnt == 0 or self.cnt == self.num_steps - 1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                # ensure both tensors are on the same device before comparison
                prev_input = self.previous_modulated_input.to(modulated_inp.device)
                curr_rel_l1 = ((modulated_inp - prev_input).abs().mean() / prev_input.abs().mean()).cpu().item()
                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
                should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

                if should_calc:
                    self.accumulated_rel_l1_distance = 0

            # explicitly store the modulated input on the current device for the next step
            self.previous_modulated_input = modulated_inp.detach().clone()

            self.cnt += 1
            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                # reuse the cached residual, placed on the same device as hidden_states
                hidden_states = hidden_states + self.previous_residual.to(hidden_states.device)
            else:
                ori_hidden_states = hidden_states.clone()

                for block_id, block in enumerate(self.transformer_blocks):
                    if self.blocks_to_swap:
                        # block swap: wait until this double block is on the compute device
                        self.offloader_double.wait_for_block(block_id)

                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                    if self.blocks_to_swap:
                        # block swap: asynchronously move finished blocks off the device
                        self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)

                for block_id, block in enumerate(self.single_transformer_blocks):
                    if self.blocks_to_swap:
                        # block swap: wait until this single block is on the compute device
                        self.offloader_single.wait_for_block(block_id)

                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                    if self.blocks_to_swap:
                        # block swap: asynchronously move finished blocks off the device
                        self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)

                # store the residual on the same device for reuse in skipped steps
                self.previous_residual = (hidden_states - ori_hidden_states).detach().clone()
                del ori_hidden_states
        else:
            for block_id, block in enumerate(self.transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_double.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)

            for block_id, block in enumerate(self.single_transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_single.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)

        del attention_mask, rope_freqs  # free memory
        del encoder_hidden_states  # free memory

        hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

        hidden_states = hidden_states[:, -original_context_length:, :]

        if self.high_quality_fp32_output_for_inference:
            hidden_states = hidden_states.to(dtype=torch.float32)
            if self.proj_out.weight.dtype != torch.float32:
                self.proj_out.to(dtype=torch.float32)

        hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

        hidden_states = einops.rearrange(
            hidden_states,
            "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)",
            t=post_patch_num_frames,
            h=post_patch_height,
            w=post_patch_width,
            pt=p_t,
            ph=p,
            pw=p,
        )

        if return_dict:
            # return Transformer2DModelOutput(sample=hidden_states)
            return SimpleNamespace(sample=hidden_states)

        return (hidden_states,)

    def fp8_optimization(
        self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
    ) -> dict[str, torch.Tensor]:
        """
        Optimize the model state_dict with fp8.

        Args:
            state_dict (dict[str, torch.Tensor]): The state_dict of the model.
            device (torch.device): The device to calculate the weight.
            move_to_device (bool): Whether to move the weight to the device after optimization.
            use_scaled_mm (bool): Whether to use scaled matrix multiplication for FP8.
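
        Returns:
            dict[str, torch.Tensor]: The optimized state_dict (the same dict, modified in place),
                ready to be loaded into the model after the FP8 monkey patch is applied.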
""" TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"] EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8 # inplace optimization state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device) # apply monkey patching apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm) return state_dict def create_hunyuan_video_transformer_3d_model(attn_mode: str, split_attn: bool = False) -> HunyuanVideoTransformer3DModelPacked: with init_empty_weights(): logger.info(f"Creating HunyuanVideoTransformer3DModelPacked") model = HunyuanVideoTransformer3DModelPacked( attention_head_dim=128, guidance_embeds=True, has_clean_x_embedder=True, has_image_proj=True, image_proj_dim=1152, in_channels=16, mlp_ratio=4.0, num_attention_heads=24, num_layers=20, num_refiner_layers=2, num_single_layers=40, out_channels=16, patch_size=2, patch_size_t=1, pooled_projection_dim=768, qk_norm="rms_norm", rope_axes_dim=(16, 56, 56), rope_theta=256.0, text_embed_dim=4096, attn_mode=attn_mode, split_attn=split_attn, ) return model def load_packed_model( device: Union[str, torch.device], dit_path: str, attn_mode: str, loading_device: Union[str, torch.device], fp8_scaled: bool = False, split_attn: bool = False, ) -> HunyuanVideoTransformer3DModelPacked: # TODO support split_attn device = torch.device(device) loading_device = torch.device(loading_device) if os.path.isdir(dit_path): # we don't support from_pretrained for now, so loading safetensors directly safetensor_files = glob.glob(os.path.join(dit_path, "*.safetensors")) if len(safetensor_files) == 0: raise ValueError(f"Cannot find safetensors file in {dit_path}") # sort by name and take the first one safetensor_files.sort() dit_path = safetensor_files[0] model = create_hunyuan_video_transformer_3d_model(attn_mode, split_attn=split_attn) # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others) dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}") # load model weights with the specified dtype or as is sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True) if fp8_scaled: # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap) logger.info(f"Optimizing model weights to fp8. This may take a while.") sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu") if loading_device.type != "cpu": # make sure all the model weights are on the loading_device logger.info(f"Moving weights to {loading_device}") for key in sd.keys(): sd[key] = sd[key].to(loading_device) info = model.load_state_dict(sd, strict=True, assign=True) logger.info(f"Loaded DiT model from {dit_path}, info={info}") return model