import os
import logging
from types import SimpleNamespace
from typing import Optional, Union

import accelerate
from accelerate import Accelerator, init_empty_weights
import torch
from safetensors.torch import load_file
from transformers import (
    LlamaTokenizerFast,
    LlamaConfig,
    LlamaModel,
    CLIPTokenizer,
    CLIPTextModel,
    CLIPConfig,
    SiglipImageProcessor,
    SiglipVisionModel,
    SiglipVisionConfig,
)

from utils.safetensors_utils import load_split_weights
from hunyuan_model.vae import load_vae as hunyuan_load_vae

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def load_vae(
    vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
):
    # single file and directory (contains 'vae') support
    if os.path.isdir(vae_path):
        vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")

    vae_dtype = torch.float16  # if vae_dtype is None else str_to_dtype(vae_dtype)
    vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
    vae.eval()
    # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

    # set chunk_size for CausalConv3d recursively
    chunk_size = vae_chunk_size
    if chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(chunk_size)
        logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")

    if vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
        logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
    # elif vae_tiling:
    else:
        vae.enable_spatial_tiling(True)

    return vae
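

# Example usage (sketch): load the VAE either from a diffusers-style directory or from a single
# .safetensors file. The path and the chunk/tile sizes below are placeholders for illustration,
# and the encode call assumes the diffusers-like `encode().latent_dist` interface of the
# HunyuanVideo causal VAE.
#
#   vae = load_vae(
#       vae_path="ckpts/vae/diffusion_pytorch_model.safetensors",  # hypothetical path
#       vae_chunk_size=32,
#       vae_spatial_tile_sample_min_size=256,
#       device="cuda",
#   )
#   with torch.no_grad():
#       latents = vae.encode(video_frames).latent_dist.sample()  # video_frames prepared elsewhere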
# region Text Encoders

# Text Encoder configs are copied from HunyuanVideo repo
LLAMA_CONFIG = {
    "architectures": ["LlamaModel"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}

CLIP_CONFIG = {
    # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
    "architectures": ["CLIPTextModel"],
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "dropout": 0.0,
    "eos_token_id": 2,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "projection_dim": 768,
    "torch_dtype": "float16",
    "transformers_version": "4.48.0.dev0",
    "vocab_size": 49408,
}
def load_text_encoder1(
    args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
) -> tuple[LlamaTokenizerFast, LlamaModel]:
    # single file, split file and directory (contains 'text_encoder') support
    logger.info("Loading text encoder 1 tokenizer")
    tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")

    logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
    if os.path.isdir(args.text_encoder1):
        # load from directory, configs are in the directory
        text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
    else:
        # load from file, we create the model with the appropriate config
        config = LlamaConfig(**LLAMA_CONFIG)
        with init_empty_weights():
            text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_split_weights(args.text_encoder1)

        # support weights from ComfyUI
        if "model.embed_tokens.weight" in state_dict:
            for key in list(state_dict.keys()):
                if key.startswith("model."):
                    new_key = key.replace("model.", "")
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
        if "tokenizer" in state_dict:
            state_dict.pop("tokenizer")
        if "lm_head.weight" in state_dict:
            state_dict.pop("lm_head.weight")

        text_encoder1.load_state_dict(state_dict, strict=True, assign=True)

    if fp8_llm:
        org_dtype = text_encoder1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder1, org_dtype)
    else:
        text_encoder1.to(device)

    text_encoder1.eval()
    return tokenizer1, text_encoder1
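

# Example usage (sketch): the loader expects an argparse-style object with a `text_encoder1`
# attribute; a SimpleNamespace works for quick tests. The checkpoint path is a placeholder.
#
#   args = SimpleNamespace(text_encoder1="ckpts/llava_llama3_fp16.safetensors")  # hypothetical path
#   tokenizer1, text_encoder1 = load_text_encoder1(args, fp8_llm=False, device="cuda")
#   tokens = tokenizer1("a cat walks on the grass", return_tensors="pt").to("cuda")
#   with torch.no_grad():
#       hidden_states = text_encoder1(**tokens, output_hidden_states=True).hidden_states[-1]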
def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
    # single file and directory (contains 'text_encoder_2') support
    logger.info("Loading text encoder 2 tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")

    logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
    if os.path.isdir(args.text_encoder2):
        # load from directory, configs are in the directory
        text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)
    else:
        # we only have one file, so we can load it directly
        config = CLIPConfig(**CLIP_CONFIG)
        with init_empty_weights():
            text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_file(args.text_encoder2)

        text_encoder2.load_state_dict(state_dict, strict=True, assign=True)

    text_encoder2.eval()
    return tokenizer2, text_encoder2
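

# Example usage (sketch; placeholder checkpoint path): the CLIP text encoder is left on CPU in
# float16 here, so move it to a device yourself if needed.
#
#   args = SimpleNamespace(text_encoder2="ckpts/clip_l.safetensors")  # hypothetical path
#   tokenizer2, text_encoder2 = load_text_encoder2(args)
#   tokens = tokenizer2("a cat walks on the grass", padding="max_length", max_length=77, truncation=True, return_tensors="pt")
#   with torch.no_grad():
#       pooled = text_encoder2(**tokens).pooler_output  # pooled embedding of shape (1, 768)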
# endregion

# region image encoder

# Siglip configs are copied from FramePack repo
FEATURE_EXTRACTOR_CONFIG = {
    "do_convert_rgb": None,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,
    "rescale_factor": 0.00392156862745098,
    "size": {"height": 384, "width": 384},
}

IMAGE_ENCODER_CONFIG = {
    "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
    "architectures": ["SiglipVisionModel"],
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.46.2",
}
def load_image_encoders(args):
    logger.info("Loading image encoder feature extractor")
    feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

    # single file, split file and directory (contains 'image_encoder') support
    logger.info(f"Loading image encoder from {args.image_encoder}")
    if os.path.isdir(args.image_encoder):
        # load from directory, configs are in the directory
        image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)
    else:
        # load from file, we create the model with the appropriate config
        config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
        with init_empty_weights():
            image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_file(args.image_encoder)

        image_encoder.load_state_dict(state_dict, strict=True, assign=True)

    image_encoder.eval()
    return feature_extractor, image_encoder
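

# Example usage (sketch; placeholder checkpoint path): preprocess a PIL image with the SigLIP
# feature extractor, cast pixel values to the encoder's dtype (float16 here), and take the
# patch-token embeddings.
#
#   from PIL import Image
#   args = SimpleNamespace(image_encoder="ckpts/sigclip_vision_patch14_384.safetensors")  # hypothetical path
#   feature_extractor, image_encoder = load_image_encoders(args)
#   image = Image.open("reference.png").convert("RGB")
#   pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(image_encoder.dtype)
#   with torch.no_grad():
#       image_embeds = image_encoder(pixel_values=pixel_values).last_hidden_state  # (1, 729, 1152) for 384x384 input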
# endregion