import os

os.environ['HF_HOME'] = os.path.abspath(
    os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))
)

import gradio as gr
import torch
import traceback
import einops
import safetensors.torch as sf
import numpy as np
import math
import spaces

from PIL import Image

# Diffusers models
from diffusers import AutoencoderKLHunyuanVideo

# Transformers models
from transformers import (
    LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer,
    AutoImageProcessor, CLIPImageProcessor, CLIPVisionModel
)

# Local helper modules
from diffusers_helper.hunyuan import (
    encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
)
from diffusers_helper.utils import (
    save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw,
    resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge,
    generate_timestamp
)
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket

# Thread utilities
from diffusers_helper.thread_utils import AsyncStream, async_run

# Gradio progress bar utils
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

# Set device to CPU
device = torch.device("cpu")

# Load models
text_encoder = LlamaModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16
).to(device)
text_encoder_2 = CLIPTextModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16
).to(device)
tokenizer = LlamaTokenizerFast.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer'
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2'
)
vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16
).to(device)

# Use the CLIP image processor / vision model instead of the Siglip classes
feature_extractor = CLIPImageProcessor.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder='feature_extractor'
)
image_encoder = CLIPVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder='image_encoder',
    torch_dtype=torch.float16, ignore_mismatched_sizes=True
).to(device)

transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
    'lllyasviel/FramePack_F1_I2V_HY_20250503', torch_dtype=torch.bfloat16
).to(device)

# Evaluation mode
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()
transformer.eval()

# Move to correct dtype
transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)

# No gradient
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
transformer.requires_grad_(False)

stream = AsyncStream()

outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)

examples = [
    ["img_examples/1.png", "The girl dances gracefully, with clear movements, full of charm."],
    ["img_examples/2.jpg", "The man dances flamboyantly, swinging his hips and striking bold poses with dramatic flair."],
    ["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."],
]
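
# Example-driven generation: runs the same worker pipeline as process(), but with
# fixed default settings, and streams (result video, live preview, status text,
# progress HTML, start/end button states) updates back to the UI.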
def generate_examples(input_image, prompt):
    t2v = False
    n_prompt = ""
    seed = 31337
    total_second_length = 60
    latent_window_size = 9
    steps = 25
    cfg = 1.0
    gs = 10.0
    rs = 0.0
    gpu_memory_preservation = 6  # unused
    use_teacache = True
    mp4_crf = 16

    global stream

    if t2v:
        default_height, default_width = 640, 640
        input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
        print("No input image provided. Using a blank white image.")

    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

    stream = AsyncStream()

    async_run(
        worker, input_image, prompt, n_prompt, seed, total_second_length,
        latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation,
        use_teacache, mp4_crf
    )

    output_filename = None

    while True:
        flag, data = stream.output_queue.next()

        if flag == 'file':
            output_filename = data
            yield (
                output_filename, gr.update(), gr.update(), gr.update(),
                gr.update(interactive=False), gr.update(interactive=True)
            )

        if flag == 'progress':
            preview, desc, html = data
            yield (
                gr.update(), gr.update(visible=True, value=preview), desc, html,
                gr.update(interactive=False), gr.update(interactive=True)
            )

        if flag == 'end':
            yield (
                output_filename, gr.update(visible=False), gr.update(), '',
                gr.update(interactive=True), gr.update(interactive=False)
            )
            break


@torch.no_grad()
def worker(
    input_image, prompt, n_prompt, seed, total_second_length, latent_window_size,
    steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
):
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    job_id = generate_timestamp()

    stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))

    try:
        # Text encoding
        llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        if cfg == 1:
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

        # Input image preprocessing
        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))

        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]

        # VAE encoding
        start_latent = vae_encode(input_image_pt, vae).to(device)

        # CLIP Vision encoding
        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

        # Move conditioning tensors to the transformer's dtype and device
        llama_vec = llama_vec.to(transformer.dtype).to(device)
        llama_vec_n = llama_vec_n.to(transformer.dtype).to(device)
        clip_l_pooler = clip_l_pooler.to(transformer.dtype).to(device)
        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype).to(device)
        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype).to(device)

        # Sampling setup
        rnd = torch.Generator("cpu").manual_seed(seed)

        history_latents = torch.zeros(
            size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32
        ).to(device)
        history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
        total_generated_latent_frames = 1
        history_pixels = None

        for section_index in range(total_latent_sections):
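            # Each section: honor a pending 'end' request, (re)initialize TeaCache,
            # then sample latent_window_size * 4 - 3 new frames conditioned on the
            # start latent plus the 1x/2x/4x slices of the most recent history latents.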
            if stream.input_queue.top() == 'end':
                stream.output_queue.push(('end', None))
                return

            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
            else:
                transformer.initialize_teacache(enable_teacache=False)

            def callback(d):
                preview = d['denoised']
                preview = vae_decode_fake(preview)
                preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')

                if stream.input_queue.top() == 'end':
                    stream.output_queue.push(('end', None))
                    raise KeyboardInterrupt('User ends the task.')

                current_step = d['i'] + 1
                percentage = int(100.0 * current_step / steps)
                hint = f'Sampling {current_step}/{steps}'
                desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
                stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                return

            indices = torch.arange(
                0, sum([1, 16, 2, 1, latent_window_size])
            ).unsqueeze(0)
            (
                clean_latent_indices_start,
                clean_latent_4x_indices,
                clean_latent_2x_indices,
                clean_latent_1x_indices,
                latent_indices
            ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

            clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
                :, :, -sum([16, 2, 1]):, :, :
            ].split([16, 2, 1], dim=2)
            clean_latents = torch.cat(
                [start_latent.to(history_latents), clean_latents_1x], dim=2
            )

            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=latent_window_size * 4 - 3,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec,
                prompt_embeds_mask=llama_attention_mask,
                prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n,
                negative_prompt_embeds_mask=llama_attention_mask_n,
                negative_prompt_poolers=clip_l_pooler_n,
                device=device,
                dtype=torch.bfloat16,
                image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                callback=callback,
            )

            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)

            real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]

            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = latent_window_size * 2
                overlapped_frames = latent_window_size * 4 - 3

                current_pixels = vae_decode(
                    real_history_latents[:, :, -section_latent_frames:], vae
                ).cpu()
                history_pixels = soft_append_bcthw(
                    history_pixels, current_pixels, overlapped_frames
                )

            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
            save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
            stream.output_queue.push(('file', output_filename))
    except Exception as e:
        traceback.print_exc()

    stream.output_queue.push(('end', None))
    return


def get_duration(
    input_image, prompt, t2v, n_prompt, seed, total_second_length, latent_window_size,
    steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf,
    quality_radio=None, aspect_ratio=None
):
    # Accept extra arguments for compatibility with process()
    return total_second_length * 60
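
# process() is the main Gradio handler: it resolves the requested quality and
# aspect ratio, prepares the input image (or a blank white canvas for
# text-to-video), launches worker() on a background thread via async_run, and
# relays its queue messages ('file' / 'progress' / 'end') to the UI as streaming updates.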
@spaces.GPU(duration=get_duration)
def process(
    input_image, prompt,
    t2v=False,
    n_prompt="",
    seed=31337,
    total_second_length=60,
    latent_window_size=9,
    steps=25,
    cfg=1.0,
    gs=10.0,
    rs=0.0,
    gpu_memory_preservation=6,
    use_teacache=True,
    mp4_crf=16,
    quality_radio="640x360",
    aspect_ratio="1:1"
):
    global stream

    # Map quality options to actual resolutions
    quality_map = {
        "360p": (640, 360),
        "480p": (854, 480),
        "540p": (960, 540),
        "720p": (1280, 720),
        "640x360": (640, 360),  # fallback
    }

    # Map aspect ratio strings to width/height ratios
    aspect_map = {
        "1:1": (1, 1),
        "3:4": (3, 4),
        "4:3": (4, 3),
        "16:9": (16, 9),
        "9:16": (9, 16),
    }

    # Get target resolution based on selected quality
    target_width, target_height = quality_map.get(quality_radio, (640, 360))

    if t2v:
        ar_w, ar_h = aspect_map.get(aspect_ratio, (1, 1))
        # Recalculate based on aspect ratio
        if ar_w >= ar_h:
            target_height = int(round(target_width * ar_h / ar_w))
        else:
            target_width = int(round(target_height * ar_w / ar_h))
        input_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
        print(f"Using blank white image for text-to-video mode, {target_width}x{target_height} ({aspect_ratio})")
    else:
        # Resize and crop input image to match selected resolution
        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=target_width)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
        job_id = generate_timestamp()  # timestamped name for the saved preprocessing snapshot
        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))

    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

    stream = AsyncStream()

    async_run(
        worker, input_image, prompt, n_prompt, seed, total_second_length,
        latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation,
        use_teacache, mp4_crf
    )

    output_filename = None

    while True:
        flag, data = stream.output_queue.next()

        if flag == 'file':
            output_filename = data
            yield (
                output_filename, gr.update(), gr.update(), gr.update(),
                gr.update(interactive=False), gr.update(interactive=True)
            )

        elif flag == 'progress':
            preview, desc, html = data
            yield (
                gr.update(), gr.update(visible=True, value=preview), desc, html,
                gr.update(interactive=False), gr.update(interactive=True)
            )

        elif flag == 'end':
            yield (
                output_filename, gr.update(visible=False), gr.update(), '',
                gr.update(interactive=True), gr.update(interactive=False)
            )
            break


def end_process():
    stream.input_queue.push('end')


quick_prompts = [
    'The girl dances gracefully, with clear movements, full of charm.',
    'A character doing some simple body movements.',
]
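# Wrap each prompt in its own single-element list so the UI can consume the
# prompts as one row per example.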
quick_prompts = [[x] for x in quick_prompts]


def make_custom_css():
    base_progress_css = make_progress_bar_css()
    extra_css = """
    body {
        background: #1a1b1e !important;
        font-family: "Noto Sans", sans-serif;
        color: #e0e0e0;
    }
    #title-container {
        text-align: center;
        padding: 20px 0;
        margin-bottom: 30px;
    }
    #title-container h1 {
        color: #4b9ffa;
        font-size: 2.5rem;
        margin: 0;
        font-weight: 800;
    }
    #title-container p {
        color: #e0e0e0;
    }
    .three-column-container {
        display: flex;
        gap: 20px;
        min-height: 800px;
        max-width: 1600px;
        margin: 0 auto;
    }
    .settings-panel {
        flex: 0 0 150px;
        background: #2a2b2e;
        padding: 12px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
    }
    .settings-panel .gr-slider {
        width: calc(100% - 10px) !important;
    }
    .settings-panel label {
        color: #e0e0e0 !important;
    }
    .settings-panel label span:first-child {
        font-size: 0.9rem !important;
    }
    .main-panel {
        flex: 1;
        background: #2a2b2e;
        padding: 20px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
        display: flex;
        flex-direction: column;
        gap: 20px;
    }
    .output-panel {
        flex: 1;
        background: #2a2b2e;
        padding: 20px;
        border-radius: 15px;
        border: 1px solid #3a3b3e;
        display: flex;
        flex-direction: column;
        align-items: center; /* Center output content */
        gap: 20px;
    }
    .output-panel > div {
        width: 100%;
        max-width: 640px; /* Limit width for better centering */
    }
    .settings-panel h3 {
        color: #4b9ffa;
        margin-bottom: 15px;
        font-size: 1.1rem;
        border-bottom: 2px solid #4b9ffa;
        padding-bottom: 8px;
    }
    .prompt-container {
        min-height: 200px;
    }
    .quick-prompts {
        margin-top: 10px;
        padding: 10px;
        background: #1a1b1e;
        border-radius: 10px;
    }
    .button-container {
        display: flex;
        gap: 10px;
        margin: 15px 0;
        justify-content: center;
        width: 100%;
    }
    /* Override Gradio's default light theme */
    .gr-box {
        background: #2a2b2e !important;
        border-color: #3a3b3e !important;
    }
    .gr-input, .gr-textbox {
        background: #1a1b1e !important;
        border-color: #3a3b3e !important;
        color: #e0e0e0 !important;
    }
    .gr-form {
        background: transparent !important;
        border: none !important;
    }
    .gr-label {
        color: #e0e0e0 !important;
    }
    .gr-button {
        background: #4b9ffa !important;
        color: white !important;
    }
    .gr-button.secondary-btn {
        background: #ff4d4d !important;
    }
    """
    return base_progress_css + extra_css


css = make_custom_css()
block = gr.Blocks(css=css).queue()

with block:
    with gr.Group(elem_id="title-container"):
        gr.Markdown("