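# FramePack F1 image-to-video Gradio demo built on HunyuanVideo components.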
import os
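# Keep the Hugging Face cache in ./hf_download next to this file.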
os.environ['HF_HOME'] = os.path.abspath(
os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))
)
import gradio as gr
import torch
import traceback
import einops
import safetensors.torch as sf
import numpy as np
import math
import spaces
from PIL import Image
# Diffusers models
from diffusers import AutoencoderKLHunyuanVideo
# Transformers models
from transformers import (
    LlamaModel, CLIPTextModel,
    LlamaTokenizerFast, CLIPTokenizer,
    CLIPImageProcessor, CLIPVisionModel
)
# Local helper modules
from diffusers_helper.hunyuan import (
encode_prompt_conds, vae_decode,
vae_encode, vae_decode_fake
)
from diffusers_helper.utils import (
save_bcthw_as_mp4, crop_or_pad_yield_mask,
soft_append_bcthw, resize_and_center_crop,
state_dict_weighted_merge, state_dict_offset_merge,
generate_timestamp
)
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
# Thread utilities
from diffusers_helper.thread_utils import AsyncStream, async_run
# Gradio progress bar utils
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
# Set device to CPU
device = torch.device("cpu")
# Load models
text_encoder = LlamaModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder',
torch_dtype=torch.float16
).to(device)
text_encoder_2 = CLIPTextModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder_2',
torch_dtype=torch.float16
).to(device)
tokenizer = LlamaTokenizerFast.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer'
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer_2'
)
vae = AutoencoderKLHunyuanVideo.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='vae',
torch_dtype=torch.float16
).to(device)
# Load the image feature extractor and encoder with CLIP classes (the upstream checkpoint uses Siglip)
feature_extractor = CLIPImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='feature_extractor'
)
image_encoder = CLIPVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='image_encoder',
torch_dtype=torch.float16,
ignore_mismatched_sizes=True
).to(device)
transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
'lllyasviel/FramePack_F1_I2V_HY_20250503',
torch_dtype=torch.bfloat16
).to(device)
# Evaluation mode
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()
transformer.eval()
# Move to correct dtype
transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)
# No gradient
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
transformer.requires_grad_(False)
stream = AsyncStream()
outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)
examples = [
["img_examples/1.png", "The girl dances gracefully, with clear movements, full of charm."],
["img_examples/2.jpg", "The man dances flamboyantly, swinging his hips and striking bold poses with dramatic flair."],
["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."]
]
def generate_examples(input_image, prompt):
    t2v = False
    n_prompt = ""
    seed = 31337
    total_second_length = 60
    latent_window_size = 9
    steps = 25
    cfg = 1.0
    gs = 10.0
    rs = 0.0
    gpu_memory_preservation = 6  # unused
    use_teacache = True
    mp4_crf = 16
global stream
if t2v:
default_height, default_width = 640, 640
input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
print("No input image provided. Using a blank white image.")
yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
stream = AsyncStream()
async_run(
worker, input_image, prompt, n_prompt, seed,
total_second_length, latent_window_size, steps,
cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
)
output_filename = None
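    # Relay worker events to the UI: 'file' delivers the (partial) MP4, 'progress' a preview frame, 'end' finishes the job.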
while True:
flag, data = stream.output_queue.next()
if flag == 'file':
output_filename = data
yield (
output_filename,
gr.update(),
gr.update(),
gr.update(),
gr.update(interactive=False),
gr.update(interactive=True)
)
if flag == 'progress':
preview, desc, html = data
yield (
gr.update(),
gr.update(visible=True, value=preview),
desc,
html,
gr.update(interactive=False),
gr.update(interactive=True)
)
if flag == 'end':
yield (
output_filename,
gr.update(visible=False),
gr.update(),
'',
gr.update(interactive=True),
gr.update(interactive=False)
)
break
@torch.no_grad()
def worker(
input_image, prompt, n_prompt, seed,
total_second_length, latent_window_size, steps,
cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
):
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
total_latent_sections = int(max(round(total_latent_sections), 1))
job_id = generate_timestamp()
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
try:
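        # Text encoding: LLaMA and CLIP text encoders build the prompt conditioning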
llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
if cfg == 1:
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
else:
llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
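        # Resize the input image to the nearest bucket resolution and keep a copy on disk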
H, W, C = input_image.shape
height, width = find_nearest_bucket(H, W, resolution=640)
input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
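        # VAE-encode the start frame and compute CLIP Vision image embeddings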
start_latent = vae_encode(input_image_pt, vae).to(device)
image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
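        # Cast all conditioning tensors to the transformer dtype on the target device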
llama_vec = llama_vec.to(transformer.dtype).to(device)
llama_vec_n = llama_vec_n.to(transformer.dtype).to(device)
clip_l_pooler = clip_l_pooler.to(transformer.dtype).to(device)
clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype).to(device)
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype).to(device)
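        # Sampling setup: seeded CPU generator plus a latent history buffer primed with the start frame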
rnd = torch.Generator("cpu").manual_seed(seed)
history_latents = torch.zeros(
size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
dtype=torch.float32
).to(device)
history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
        total_generated_latent_frames = 1
        history_pixels = None  # decoded pixel history, filled in section by section
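        # Generate the video section by section, extending the latent history each time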
for section_index in range(total_latent_sections):
if stream.input_queue.top() == 'end':
stream.output_queue.push(('end', None))
return
if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
else:
transformer.initialize_teacache(enable_teacache=False)
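            # Per-step callback: pushes a decoded preview and progress HTML, and honors the Stop button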
def callback(d):
preview = d['denoised']
preview = vae_decode_fake(preview)
preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
if stream.input_queue.top() == 'end':
stream.output_queue.push(('end', None))
raise KeyboardInterrupt('User ends the task.')
current_step = d['i'] + 1
percentage = int(100.0 * current_step / steps)
hint = f'Sampling {current_step}/{steps}'
desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
return
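            # Index layout for the packed model: start frame, 16 coarse (4x) + 2 medium (2x) + 1 recent (1x) context latents, then the new window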
indices = torch.arange(
0, sum([1, 16, 2, 1, latent_window_size])
).unsqueeze(0)
(
clean_latent_indices_start,
clean_latent_4x_indices,
clean_latent_2x_indices,
clean_latent_1x_indices,
latent_indices
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
:, :, -sum([16, 2, 1]):, :, :
].split([16, 2, 1], dim=2)
clean_latents = torch.cat(
[start_latent.to(history_latents), clean_latents_1x],
dim=2
)
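            # Sample one section of latents with the packed Hunyuan transformer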
generated_latents = sample_hunyuan(
transformer=transformer,
sampler='unipc',
width=width,
height=height,
frames=latent_window_size * 4 - 3,
real_guidance_scale=cfg,
distilled_guidance_scale=gs,
guidance_rescale=rs,
num_inference_steps=steps,
generator=rnd,
prompt_embeds=llama_vec,
prompt_embeds_mask=llama_attention_mask,
prompt_poolers=clip_l_pooler,
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
device=device,
dtype=torch.bfloat16,
image_embeddings=image_encoder_last_hidden_state,
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
callback=callback,
)
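            # Append the new latents and decode the recent history to pixels, blending overlapping frames with earlier sections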
total_generated_latent_frames += int(generated_latents.shape[2])
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
if history_pixels is None:
history_pixels = vae_decode(real_history_latents, vae).cpu()
else:
section_latent_frames = latent_window_size * 2
overlapped_frames = latent_window_size * 4 - 3
current_pixels = vae_decode(
real_history_latents[:, :, -section_latent_frames:], vae
).cpu()
history_pixels = soft_append_bcthw(
history_pixels, current_pixels, overlapped_frames
)
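            # Write the output MP4 after every section so the UI can show partial results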
output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
stream.output_queue.push(('file', output_filename))
    except Exception:
traceback.print_exc()
stream.output_queue.push(('end', None))
return
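# Estimate the @spaces.GPU duration (in seconds) from the requested video length.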
def get_duration(
input_image, prompt, t2v, n_prompt,
seed, total_second_length, latent_window_size,
steps, cfg, gs, rs, gpu_memory_preservation,
use_teacache, mp4_crf, quality_radio=None, aspect_ratio=None
):
# Accept extra arguments for compatibility with process()
return total_second_length * 60
@spaces.GPU(duration=get_duration)
def process(
input_image, prompt, t2v=False, n_prompt="", seed=31337,
total_second_length=60, latent_window_size=9, steps=25,
cfg=1.0, gs=10.0, rs=0.0, gpu_memory_preservation=6,
use_teacache=True, mp4_crf=16, quality_radio="640x360", aspect_ratio="1:1"
):
global stream
# Map quality options to actual resolutions
quality_map = {
"360p": (640, 360),
"480p": (854, 480),
"540p": (960, 540),
"720p": (1280, 720),
"640x360": (640, 360), # fallback
}
# Map aspect ratio strings to width/height ratios
aspect_map = {
"1:1": (1, 1),
"3:4": (3, 4),
"4:3": (4, 3),
"16:9": (16, 9),
"9:16": (9, 16),
}
# Get target resolution based on selected quality
target_width, target_height = quality_map.get(quality_radio, (640, 360))
if t2v:
ar_w, ar_h = aspect_map.get(aspect_ratio, (1, 1))
# Recalculate based on aspect ratio
if ar_w >= ar_h:
target_height = int(round(target_width * ar_h / ar_w))
else:
target_width = int(round(target_height * ar_w / ar_h))
input_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
print(f"Using blank white image for text-to-video mode, {target_width}x{target_height} ({aspect_ratio})")
else:
        # Resize and center-crop the input image to the selected resolution
        # (worker() re-buckets internally at resolution 640; this copy is saved only as a reference frame)
        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=target_width)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
        job_id = generate_timestamp()
        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
stream = AsyncStream()
async_run(
worker, input_image, prompt, n_prompt, seed,
total_second_length, latent_window_size, steps,
cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
)
output_filename = None
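    # Relay worker events to the UI until the 'end' flag arrives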
while True:
flag, data = stream.output_queue.next()
if flag == 'file':
output_filename = data
yield (
output_filename,
gr.update(),
gr.update(),
gr.update(),
gr.update(interactive=False),
gr.update(interactive=True)
)
elif flag == 'progress':
preview, desc, html = data
yield (
gr.update(),
gr.update(visible=True, value=preview),
desc,
html,
gr.update(interactive=False),
gr.update(interactive=True)
)
elif flag == 'end':
yield (
output_filename,
gr.update(visible=False),
gr.update(),
'',
gr.update(interactive=True),
gr.update(interactive=False)
)
break
def end_process():
stream.input_queue.push('end')
quick_prompts = [
'The girl dances gracefully, with clear movements, full of charm.',
'A character doing some simple body movements.'
]
quick_prompts = [[x] for x in quick_prompts]
def make_custom_css():
base_progress_css = make_progress_bar_css()
extra_css = """
body {
background: #1a1b1e !important;
font-family: "Noto Sans", sans-serif;
color: #e0e0e0;
}
#title-container {
text-align: center;
padding: 20px 0;
margin-bottom: 30px;
}
#title-container h1 {
color: #4b9ffa;
font-size: 2.5rem;
margin: 0;
font-weight: 800;
}
#title-container p {
color: #e0e0e0;
}
.three-column-container {
display: flex;
gap: 20px;
min-height: 800px;
max-width: 1600px;
margin: 0 auto;
}
.settings-panel {
flex: 0 0 150px;
background: #2a2b2e;
padding: 12px;
border-radius: 15px;
border: 1px solid #3a3b3e;
}
.settings-panel .gr-slider {
width: calc(100% - 10px) !important;
}
.settings-panel label {
color: #e0e0e0 !important;
}
.settings-panel label span:first-child {
font-size: 0.9rem !important;
}
.main-panel {
flex: 1;
background: #2a2b2e;
padding: 20px;
border-radius: 15px;
border: 1px solid #3a3b3e;
display: flex;
flex-direction: column;
gap: 20px;
}
.output-panel {
flex: 1;
background: #2a2b2e;
padding: 20px;
border-radius: 15px;
border: 1px solid #3a3b3e;
display: flex;
flex-direction: column;
align-items: center; /* Center output content */
gap: 20px;
}
.output-panel > div {
width: 100%;
max-width: 640px; /* Limit width for better centering */
}
.settings-panel h3 {
color: #4b9ffa;
margin-bottom: 15px;
font-size: 1.1rem;
border-bottom: 2px solid #4b9ffa;
padding-bottom: 8px;
}
.prompt-container {
min-height: 200px;
}
.quick-prompts {
margin-top: 10px;
padding: 10px;
background: #1a1b1e;
border-radius: 10px;
}
.button-container {
display: flex;
gap: 10px;
margin: 15px 0;
justify-content: center;
width: 100%;
}
/* Override Gradio's default light theme */
.gr-box {
background: #2a2b2e !important;
border-color: #3a3b3e !important;
}
.gr-input, .gr-textbox {
background: #1a1b1e !important;
border-color: #3a3b3e !important;
color: #e0e0e0 !important;
}
.gr-form {
background: transparent !important;
border: none !important;
}
.gr-label {
color: #e0e0e0 !important;
}
.gr-button {
background: #4b9ffa !important;
color: white !important;
}
.gr-button.secondary-btn {
background: #ff4d4d !important;
}
"""
return base_progress_css + extra_css
css = make_custom_css()
block = gr.Blocks(css=css).queue()
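# Build the Gradio UI: a three-column layout (settings, main input, output) styled by the custom CSS above.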
with block:
with gr.Group(elem_id="title-container"):
gr.Markdown("<h1>FramePack</h1>")
gr.Markdown(
"""Generate amazing animations from a single image using AI.
Just upload an image, write a prompt, and watch the magic happen!"""
)
with gr.Row(elem_classes="three-column-container"):
# Left Column - Settings
with gr.Column(elem_classes="settings-panel"):
gr.Markdown("### Generation Settings")
with gr.Group():
total_second_length = gr.Slider(
label="Duration (Seconds)",
minimum=1,
maximum=10,
value=2,
step=1,
info='Length of generated video'
)
steps = gr.Slider(
label="Quality Steps",
minimum=1,
maximum=100,
value=15,
step=1,
info='25-30 recommended'
)
gs = gr.Slider(
label="Animation Strength",
minimum=1.0,
maximum=32.0,
value=10.0,
step=0.1,
info='8-12 recommended'
)
quality_radio = gr.Radio(
label="Video Quality (Resolution)",
choices=["360p", "480p", "540p", "720p"],
value="640x360",
info="Choose output video resolution"
)
# Aspect ratio dropdown, hidden by default
aspect_ratio = gr.Dropdown(
label="Aspect Ratio",
choices=["1:1", "3:4", "4:3", "16:9", "9:16"],
value="1:1",
visible=False,
info="Only applies to Text to Video mode"
)
gr.Markdown("### Advanced")
with gr.Group():
t2v = gr.Checkbox(
label='Text to Video Mode',
value=False,
info='Generate without input image'
)
use_teacache = gr.Checkbox(
label='Fast Mode',
value=True,
info='Faster but may affect details'
)
gpu_memory_preservation = gr.Slider(
label="VRAM Usage",
minimum=6,
maximum=128,
value=6,
step=1
)
seed = gr.Number(
label="Seed",
value=31337,
precision=0
)
# Hidden settings
n_prompt = gr.Textbox(visible=False, value="")
latent_window_size = gr.Slider(visible=False, value=9)
cfg = gr.Slider(visible=False, value=1.0)
rs = gr.Slider(visible=False, value=0.0)
            mp4_crf = gr.Number(visible=False, value=16)  # hidden MP4 CRF (quality) setting
# Middle Column - Main Content
with gr.Column(elem_classes="main-panel"):
input_image = gr.Image(
label="Upload Your Image",
type="numpy",
height=320
)
            # Generation control buttons
with gr.Group(elem_classes="button-container"):
start_button = gr.Button(
value="▶️ Generate Animation",
elem_classes=["primary-btn"]
)
stop_button = gr.Button(
value="⏹️ Stop",
elem_classes=["secondary-btn"],
interactive=False
)
with gr.Group(elem_classes="prompt-container"):
prompt = gr.Textbox(
label="Describe the animation you want",
placeholder="E.g., The character dances gracefully with flowing movements...",
lines=4
)
with gr.Group(elem_classes="quick-prompts"):
gr.Markdown("### 💡 Quick Prompts")
example_quick_prompts = gr.Dataset(
samples=quick_prompts,
label='Click to use',
samples_per_page=3,
components=[prompt]
)
# Right Column - Output
with gr.Column(elem_classes="output-panel"):
preview_image = gr.Image(
label="Generation Preview",
height=200,
visible=False
)
result_video = gr.Video(
label="Generated Animation",
autoplay=True,
show_share_button=True,
height=400,
loop=True
)
with gr.Group(elem_classes="progress-container"):
progress_desc = gr.Markdown(
elem_classes='no-generating-animation'
)
progress_bar = gr.HTML(
elem_classes='no-generating-animation'
)
# Setup callbacks
ips = [
input_image, prompt, t2v, n_prompt, seed,
total_second_length, latent_window_size,
steps, cfg, gs, rs, gpu_memory_preservation,
        use_teacache, mp4_crf,
quality_radio, aspect_ratio
]
start_button.click(
fn=process,
inputs=ips,
outputs=[
result_video, preview_image,
progress_desc, progress_bar,
start_button, stop_button
]
)
stop_button.click(fn=end_process)
example_quick_prompts.click(
fn=lambda x: x[0],
inputs=[example_quick_prompts],
outputs=prompt,
show_progress=False,
queue=False
)
# Show/hide aspect ratio dropdown based on t2v checkbox
def show_aspect_ratio(t2v_checked):
return gr.update(visible=bool(t2v_checked))
t2v.change(
fn=show_aspect_ratio,
inputs=[t2v],
outputs=[aspect_ratio],
queue=False
)
block.launch(share=True)