import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ๋ชจ๋ธ ์„ค์ •
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
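# CausVid is a distillation LoRA for Wan 2.1 that makes the low step counts
# this demo runs with (roughly 4-8) produce usable results.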
# ํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ •
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 512
DEFAULT_W_SLIDER_VALUE = 512  # square default for Zero GPU
NEW_FORMULA_MAX_AREA = 480.0 * 832.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
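# 81 frames at 24 fps is roughly 3.4 s, the longest clip supported here.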
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, blurred, low quality, watermark, text"
# ๋ชจ๋ธ ๊ธ€๋กœ๋ฒŒ ๋กœ๋”ฉ
logger.info("Loading model components...")
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
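# UniPC is swapped in for few-step sampling; flow_shift reshapes the
# flow-matching timestep schedule (higher values bias sampling toward the
# high-noise end).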
pipe.to("cuda")
# Load the CausVid LoRA
try:
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
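    # fuse_lora() merges the adapter into the base weights so sampling pays
    # no per-step LoRA overhead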
pipe.fuse_lora()
logger.info("LoRA loaded successfully")
except Exception as e:
logger.warning(f"LoRA loading failed: {e}")
# Enable memory optimizations
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
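# Note: enable_model_cpu_offload() installs accelerate hooks that stream each
# component to the GPU on demand, superseding the blanket pipe.to("cuda") above.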
logger.info("Model loaded and ready")
def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
min_slider_h, max_slider_h,
min_slider_w, max_slider_w,
default_h, default_w):
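    """Choose an output size that preserves the input image's aspect ratio.

    Height and width are solved so that their product stays near
    calculation_max_area, snapped down to multiples of mod_val, and clamped
    to the slider range. Falls back to the defaults for degenerate images.
    """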
orig_w, orig_h = pil_image.size
if orig_w <= 0 or orig_h <= 0:
return default_h, default_w
aspect_ratio = orig_h / orig_w
    # Conservative sizing for Zero GPU
    if hasattr(spaces, 'GPU'):
        # Use a smaller maximum area
        calculation_max_area = min(calculation_max_area, 320.0 * 320.0)
calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
    # Additional caps in the Zero GPU environment
if hasattr(spaces, 'GPU'):
max_slider_h = min(max_slider_h, 640)
max_slider_w = min(max_slider_w, 640)
new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
return new_h, new_w
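# Worked example (Zero GPU): a 1280x720 upload gives aspect_ratio 0.5625 and an
# area cap of 320*320 = 102400, so calc_h = round(sqrt(102400 * 0.5625)) = 240
# and calc_w = round(sqrt(102400 / 0.5625)) = 427; snapping down to multiples
# of 32 yields a final 224x416 output.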
def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
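    """Sync the height/width sliders to dimensions derived from the upload."""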
if uploaded_pil_image is None:
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
try:
new_h, new_w = _calculate_new_dimensions_wan(
uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
)
return gr.update(value=new_h), gr.update(value=new_w)
except Exception as e:
gr.Warning("Error attempting to calculate new dimensions")
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
def get_duration(input_image, prompt, height, width,
negative_prompt, duration_seconds,
guidance_scale, steps,
seed, randomize_seed,
progress):
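    """Estimate the GPU allocation (in seconds) for one generation call.

    Passed as the `duration` argument to @spaces.GPU below; ZeroGPU calls it
    with the same arguments as the decorated function, so heavier requests
    (more steps, longer clips) can reserve a longer slot.
    """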
    # Conservative GPU-time allocation for Zero GPU
    if hasattr(spaces, 'GPU'):
        # Allocate extra headroom in the Zero GPU environment
if steps > 4 and duration_seconds > 2:
return 90
elif steps > 4 or duration_seconds > 2:
return 80
else:
return 70
else:
        # Standard GPU environment
if steps > 4 and duration_seconds > 2:
return 90
elif steps > 4 or duration_seconds > 2:
return 75
else:
return 60
@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt, duration_seconds=2,
                   guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
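    """Animate a single input image into a short video with Wan 2.1 I2V.

    Returns a (video_path, seed) tuple matching the Gradio outputs.
    """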
if input_image is None:
raise gr.Error("Please upload an input image.")
    # Extra validation in the Zero GPU environment
    if hasattr(spaces, 'GPU'):
        # Pixel-count cap
        max_pixels = 409600  # 640x640
        if height * width > max_pixels:
            raise gr.Error(f"Resolution too high for Zero GPU. Maximum {max_pixels:,} pixels (e.g., 640×640)")
        # Duration cap
        if duration_seconds > 2.5:
            duration_seconds = 2.5
            gr.Warning("Duration limited to 2.5s in Zero GPU environment")
        # Steps cap
        if steps > 8:
            steps = 8
            gr.Warning("Steps limited to 8 in Zero GPU environment")
target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
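    # e.g. 2 s * 24 fps = 48 frames, clamped to the model's 8-81 frame range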
# Zero GPU์—์„œ ํ”„๋ ˆ์ž„ ์ˆ˜ ์ถ”๊ฐ€ ์ œํ•œ
if hasattr(spaces, 'GPU'):
max_frames_zerogpu = int(2.5 * FIXED_FPS) # 2.5์ดˆ
num_frames = min(num_frames, max_frames_zerogpu)
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
logger.info(f"Generating video: {target_h}x{target_w}, {num_frames} frames, seed={current_seed}")
    # Resize the input image to the target resolution
resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
try:
with torch.inference_mode():
output_frames_list = pipe(
image=resized_image,
prompt=prompt,
negative_prompt=negative_prompt,
height=target_h,
width=target_w,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed)
).frames[0]
except torch.cuda.OutOfMemoryError:
gc.collect()
torch.cuda.empty_cache()
raise gr.Error("GPU out of memory. Try smaller resolution or shorter duration.")
except Exception as e:
logger.error(f"Generation failed: {e}")
raise gr.Error(f"Video generation failed: {str(e)[:100]}")
    # Save the video to a temporary file
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    # Free memory
del output_frames_list
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return video_path, current_seed
# CSS styling (preserves the existing UI)
css = """
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 30px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 40px;
border-radius: 20px;
color: white;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
position: relative;
overflow: hidden;
}
.header::before {
content: '';
position: absolute;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
animation: pulse 4s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { transform: scale(1); opacity: 0.5; }
50% { transform: scale(1.1); opacity: 0.8; }
}
.header h1 {
font-size: 3em;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
position: relative;
z-index: 1;
}
.header p {
font-size: 1.2em;
opacity: 0.95;
position: relative;
z-index: 1;
}
.gpu-status {
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.3);
padding: 5px 15px;
border-radius: 20px;
font-size: 0.8em;
}
.main-content {
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
padding: 30px;
box-shadow: 0 5px 20px rgba(0,0,0,0.1);
backdrop-filter: blur(10px);
}
.input-section {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 25px;
border-radius: 15px;
margin-bottom: 20px;
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.3em;
padding: 15px 40px;
border-radius: 30px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
width: 100%;
margin-top: 20px;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.generate-btn:active {
transform: translateY(0);
}
.video-output {
background: #f8f9fa;
padding: 20px;
border-radius: 15px;
text-align: center;
min-height: 400px;
display: flex;
align-items: center;
justify-content: center;
}
.accordion {
background: rgba(255, 255, 255, 0.7);
border-radius: 10px;
margin-top: 15px;
padding: 15px;
}
.slider-container {
background: rgba(255, 255, 255, 0.5);
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
body {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
}
@keyframes gradient {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.warning-box {
background: rgba(255, 193, 7, 0.1);
border: 1px solid rgba(255, 193, 7, 0.3);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
color: #856404;
font-size: 0.9em;
}
.info-box {
background: rgba(52, 152, 219, 0.1);
border: 1px solid rgba(52, 152, 219, 0.3);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
color: #2c5282;
font-size: 0.9em;
}
.footer {
text-align: center;
margin-top: 30px;
color: #666;
font-size: 0.9em;
}
"""
# Gradio UI (preserves the existing structure)
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
with gr.Column(elem_classes="container"):
# Header with GPU status
gr.HTML("""
<div class="header">
            <h1>🎬 AI Video Magic Studio</h1>
            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
            <div class="gpu-status">🖥️ Zero GPU Optimized</div>
</div>
""")
        # GPU memory warning
if hasattr(spaces, 'GPU'):
gr.HTML("""
<div class="warning-box">
                <strong>💡 Zero GPU Performance Tips:</strong>
<ul style="margin: 5px 0; padding-left: 20px;">
<li>Maximum duration: 2.5 seconds</li>
                    <li>Maximum resolution: 640×640 pixels</li>
                    <li>Recommended: 512×512 at 2 seconds</li>
<li>Use 4-6 steps for optimal speed/quality balance</li>
<li>Processing time: ~60-90 seconds</li>
</ul>
</div>
""")
        # Info box
gr.HTML("""
<div class="info-box">
                <strong>🎯 Quick Start Guide:</strong>
<ol style="margin: 5px 0; padding-left: 20px;">
<li>Upload your image - AI will calculate optimal dimensions</li>
<li>Enter a creative prompt or use the default</li>
<li>Adjust duration (2s recommended for best results)</li>
<li>Click Generate and wait for completion</li>
</ol>
</div>
""")
with gr.Row(elem_classes="main-content"):
with gr.Column(scale=1):
gr.Markdown("### ๐Ÿ“ธ Input Settings")
with gr.Column(elem_classes="input-section"):
input_image = gr.Image(
type="pil",
label="๐Ÿ–ผ๏ธ Upload Your Image",
elem_classes="image-upload"
)
prompt_input = gr.Textbox(
label="โœจ Animation Prompt",
value=default_prompt_i2v,
placeholder="Describe how you want your image to move...",
lines=2
)
duration_input = gr.Slider(
minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1) if not hasattr(spaces, 'GPU') else 2.5,
step=0.1,
value=2,
label=f"โฑ๏ธ Video Duration (seconds) - Clamped to {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps",
elem_classes="slider-container"
)
                    with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
negative_prompt = gr.Textbox(
label="๐Ÿšซ Negative Prompt",
value=default_negative_prompt,
lines=3
)
with gr.Row():
seed = gr.Slider(
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
label="๐ŸŽฒ Seed"
)
randomize_seed = gr.Checkbox(
label="๐Ÿ”€ Randomize",
value=True
)
with gr.Row():
height_slider = gr.Slider(
minimum=SLIDER_MIN_H,
maximum=SLIDER_MAX_H if not hasattr(spaces, 'GPU') else 640,
step=MOD_VALUE,
value=DEFAULT_H_SLIDER_VALUE,
label=f"๐Ÿ“ Height (multiple of {MOD_VALUE})"
)
width_slider = gr.Slider(
minimum=SLIDER_MIN_W,
maximum=SLIDER_MAX_W if not hasattr(spaces, 'GPU') else 640,
step=MOD_VALUE,
value=DEFAULT_W_SLIDER_VALUE,
label=f"๐Ÿ“ Width (multiple of {MOD_VALUE})"
)
steps_slider = gr.Slider(
minimum=1,
maximum=30 if not hasattr(spaces, 'GPU') else 8,
step=1,
value=4,
label="๐Ÿ”ง Quality Steps (4-6 recommended)"
)
guidance_scale = gr.Slider(
minimum=0.0,
maximum=20.0,
step=0.5,
value=1.0,
label="๐ŸŽฏ Guidance Scale",
visible=False
)
generate_btn = gr.Button(
"๐ŸŽฌ Generate Video",
variant="primary",
elem_classes="generate-btn"
)
with gr.Column(scale=1):
gr.Markdown("### ๐ŸŽฅ Generated Video")
video_output = gr.Video(
label="",
autoplay=True,
elem_classes="video-output"
)
gr.HTML("""
<div class="footer">
                    <p>💡 Tip: For best results, use clear images with good lighting and distinct subjects</p>
</div>
""")
# Examples
gr.Examples(
examples=[
["peng.png", "a penguin playfully dancing in the snow, Antarctica", 512, 512],
["forg.jpg", "the frog jumps around", 448, 576],
],
inputs=[input_image, prompt_input, height_slider, width_slider],
outputs=[video_output, seed],
fn=generate_video,
            cache_examples=False  # disable example caching to save memory
)
        # Summary of key features
gr.HTML("""
<div style="background: rgba(255,255,255,0.9); border-radius: 10px; padding: 15px; margin-top: 20px; font-size: 0.8em; text-align: center;">
<p style="margin: 0; color: #666;">
                <strong style="color: #667eea;">Powered by:</strong>
                Wan 2.1 I2V (14B) + CausVid LoRA • 🚀 Fast 4-8 step inference • 🎬 Up to 81 frames
</p>
</div>
""")
# Event handlers
input_image.upload(
fn=handle_image_upload_for_dims_wan,
inputs=[input_image, height_slider, width_slider],
outputs=[height_slider, width_slider]
)
input_image.clear(
fn=handle_image_upload_for_dims_wan,
inputs=[input_image, height_slider, width_slider],
outputs=[height_slider, width_slider]
)
generate_btn.click(
fn=generate_video,
inputs=[
input_image, prompt_input, height_slider, width_slider,
negative_prompt, duration_input, guidance_scale,
steps_slider, seed, randomize_seed
],
outputs=[video_output, seed]
)
if __name__ == "__main__":
demo.queue().launch()