# Hugging Face Space status (page header): Running on Zero (ZeroGPU).
import torch | |
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler | |
from diffusers.utils import export_to_video | |
from transformers import CLIPVisionModel | |
import gradio as gr | |
import tempfile | |
import spaces | |
from huggingface_hub import hf_hub_download | |
import numpy as np | |
from PIL import Image | |
import random | |
import logging | |
import gc | |
import time | |
import hashlib | |
from dataclasses import dataclass | |
from typing import Optional, Tuple | |
from functools import wraps | |
import threading | |
import os | |
# GPU memory management settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'  # use smaller allocator chunks
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Settings management
@dataclass
class VideoGenerationConfig:
    """Tunable settings for the image-to-video demo.

    Declared as a dataclass so individual settings can be overridden at
    construction time (``VideoGenerationConfig(max_frames=24)``) while
    ``VideoGenerationConfig()`` keeps returning the defaults below.
    """

    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    mod_value: int = 32
    # Very conservative defaults for Zero GPU
    default_height: int = 320
    default_width: int = 320
    max_area: float = 320.0 * 320.0  # tuned for Zero GPU
    slider_min_h: int = 128
    slider_max_h: int = 512  # lowered maximum
    slider_min_w: int = 128
    slider_max_w: int = 512  # lowered maximum
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 30  # lowered maximum frame count (1.25 s at 24 fps)
    default_prompt: str = "make this image move, smooth motion"
    default_negative_prompt: str = "static, blur"
    # GPU memory optimization settings
    enable_model_cpu_offload: bool = True
    enable_vae_slicing: bool = True
    enable_vae_tiling: bool = True

    def max_duration(self) -> float:
        """Return the maximum allowed duration in seconds (max_frames / fixed_fps)."""
        return self.max_frames / self.fixed_fps

    def min_duration(self) -> float:
        """Return the minimum allowed duration in seconds (min_frames / fixed_fps)."""
        return self.min_frames / self.fixed_fps
# Shared configuration instance used by the rest of the module.
config = VideoGenerationConfig()
# Largest value of a signed 32-bit integer; upper bound for seeds.
MAX_SEED = np.iinfo(np.int32).max
# Global state
pipe = None  # pipeline handle; generate_video loads it while this is None
generation_lock = threading.Lock()  # acquired by generate_video for the duration of a run
# Performance-measurement decorator
def measure_time(func):
    """Return a wrapper around *func* that logs how long each call took.

    The elapsed wall-clock time is logged at INFO level after *func*
    returns; the wrapped function's return value is passed through
    unchanged. ``functools.wraps`` keeps the wrapped function's name,
    docstring and other metadata on the wrapper.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
        return result
    return wrapper
# GPU memory cleanup helper
def clear_gpu_memory():
    """Best-effort memory cleanup (safe to call on Zero GPU).

    Runs the Python garbage collector and, when CUDA is available, empties
    the CUDA cache and synchronizes. CUDA failures are logged at DEBUG
    level and otherwise ignored, so this never raises from the cleanup
    step itself.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    except Exception as exc:
        # Deliberately non-fatal, but catch Exception (not a bare except) so
        # KeyboardInterrupt/SystemExit still propagate.
        logger.debug("GPU cache cleanup skipped: %s", exc)
# Video generator class
class VideoGenerator:
    """Sizing, validation and file-naming helpers driven by a config object.

    Annotations that name project or third-party types are written as
    strings so defining the class does not require resolving them.
    """

    def __init__(self, config: "VideoGenerationConfig"):
        self.config = config

    def calculate_dimensions(self, image: "Image.Image") -> Tuple[int, int]:
        """Suggest an output ``(height, width)`` for *image*.

        The image's aspect ratio (clamped to the 1:2 .. 2:1 range) is fitted
        into ``config.max_area`` pixels, each side is floored to a multiple
        of ``config.mod_value`` and clipped to the configured slider range.
        Returns the configured default size when the image reports a
        non-positive width or height.
        """
        cfg = self.config
        mod = cfg.mod_value
        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return cfg.default_height, cfg.default_width

        # Clamp overly extreme aspect ratios.
        aspect_ratio = min(max(orig_h / orig_w, 0.5), 2.0)
        # Pixel budget (very small resolution tuned for Zero GPU).
        max_area = cfg.max_area

        calc_h = round(np.sqrt(max_area * aspect_ratio))
        calc_w = round(np.sqrt(max_area / aspect_ratio))
        # Snap down to a multiple of mod_value, but never below one unit.
        calc_h = max(mod, (calc_h // mod) * mod)
        calc_w = max(mod, (calc_w // mod) * mod)

        # Keep each side inside the slider range, then re-snap.
        new_h = int(np.clip(calc_h, cfg.slider_min_h, cfg.slider_max_h))
        new_w = int(np.clip(calc_w, cfg.slider_min_w, cfg.slider_max_w))
        new_h = (new_h // mod) * mod
        new_w = (new_w // mod) * mod

        # Final pixel-count check: shrink proportionally if clipping pushed
        # the area over the budget.
        if new_h * new_w > max_area:
            scale = np.sqrt(max_area / (new_h * new_w))
            new_h = int((new_h * scale) // mod) * mod
            new_w = int((new_w * scale) // mod) * mod
        return new_h, new_w

    def validate_inputs(self, image: "Image.Image", prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        """Check user inputs against the Zero GPU limits.

        Returns ``(True, None)`` when everything is acceptable, otherwise
        ``(False, message)`` with a user-facing explanation. The numeric
        limits are literals because the messages quote them.
        """
        if image is None:
            return False, "๐ผ๏ธ Please upload an input image"
        if not prompt or not prompt.strip():
            return False, "โ๏ธ Please provide a prompt"
        if len(prompt) > 200:  # shorter prompt limit
            return False, "โ ๏ธ Prompt is too long (max 200 characters)"

        # Limits tuned for Zero GPU
        if duration < 0.3:
            return False, "โฑ๏ธ Duration too short (min 0.3s)"
        if duration > 1.2:  # shorter maximum duration
            return False, "โฑ๏ธ Duration too long (max 1.2s for stability)"

        # Reject non-positive sizes up front; the aspect-ratio check below
        # would otherwise divide by zero.
        if height <= 0 or width <= 0:
            return False, "Height and width must be positive"

        # Pixel-count limit (more conservative)
        max_pixels = 320 * 320  # 102,400 pixels
        if height * width > max_pixels:
            return False, f"๐ Total pixels limited to {max_pixels:,} (e.g., 320ร320, 256ร384)"
        if height > 512 or width > 512:  # lowered maximum
            return False, "๐ Maximum dimension is 512 pixels"

        # Aspect-ratio check
        aspect_ratio = max(height / width, width / height)
        if aspect_ratio > 2.0:
            return False, "๐ Aspect ratio too extreme (max 2:1 or 1:2)"

        if steps < 1:
            return False, "At least 1 step is required"
        if steps > 5:  # lowered maximum step count
            return False, "๐ง Maximum 5 steps in Zero GPU environment"
        return True, None

    def generate_unique_filename(self, seed: int) -> str:
        """Return ``video_<8 hex chars>.mp4`` derived from time, seed and a random number."""
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"
# Module-level generator helper bound to the shared config.
video_generator = VideoGenerator(config)
# Gradio callbacks
def handle_image_upload(image):
    """Return slider updates (height, width) for a newly uploaded image.

    With no image, or when the image cannot be processed, both sliders are
    reset to the configured default size; a processing failure is logged
    and surfaced to the user as a Gradio warning.
    """
    def _default_updates():
        return gr.update(value=config.default_height), gr.update(value=config.default_width)

    if image is None:
        return _default_updates()

    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
        suggested_h, suggested_w = video_generator.calculate_dimensions(image)
        return gr.update(value=suggested_h), gr.update(value=suggested_w)
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        gr.Warning("โ ๏ธ Error processing image")
        return _default_updates()
def get_duration(input_image, prompt, height, width, negative_prompt,
                 duration_seconds, guidance_scale, steps, seed, randomize_seed,
                 progress=None):
    """Estimate a time budget (in seconds, capped at 90) for one generation.

    The parameter list mirrors ``generate_video``; only ``height``,
    ``width`` and ``steps`` influence the result. Very conservative
    allocation for the Zero GPU environment.

    ``progress`` is optional so the function can also be called without it.
    Non-positive dimensions are tolerated (no aspect-ratio surcharge)
    instead of raising ``ZeroDivisionError``.
    """
    budget = 50  # base allowance of 50 seconds

    # Extra time by pixel count
    pixels = height * width
    if pixels > 147456:  # above 384x384
        budget += 20
    elif pixels > 100000:  # above ~316x316
        budget += 10

    # Extra time by step count
    if steps > 4:
        budget += 15
    elif steps > 2:
        budget += 10

    # Extra time for extreme aspect ratios (beyond 3:2)
    if height > 0 and width > 0 and max(height / width, width / height) > 1.5:
        budget += 10

    # Cap at 90 seconds
    return min(budget, 90)
def _load_pipeline():
    """Build the Wan image-to-video pipeline and apply the configured memory options."""
    logger.info("Loading model components...")
    # Load components
    image_encoder = CLIPVisionModel.from_pretrained(
        config.model_id,
        subfolder="image_encoder",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    vae = AutoencoderKLWan.from_pretrained(
        config.model_id,
        subfolder="vae",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    loaded = WanImageToVideoPipeline.from_pretrained(
        config.model_id,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    # Scheduler setup
    loaded.scheduler = UniPCMultistepScheduler.from_config(
        loaded.scheduler.config, flow_shift=8.0
    )
    # LoRA loading is skipped (for stability)
    logger.info("Skipping LoRA for stability")
    # Enable optimizations according to the config flags
    if config.enable_vae_slicing:
        loaded.enable_vae_slicing()
    if config.enable_vae_tiling:
        loaded.enable_vae_tiling()
    if config.enable_model_cpu_offload:
        # Model CPU offload (saves memory) handles device placement itself,
        # so the pipeline is not moved to CUDA first.
        loaded.enable_model_cpu_offload()
    else:
        loaded.to("cuda")
    logger.info("Model loaded successfully")
    return loaded


@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=0.8, guidance_scale=1, steps=3,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate a short video from *input_image* and return ``(video_path, seed)``.

    Decorated with ``spaces.GPU`` so the call gets a GPU on Zero GPU
    hardware, with the time budget computed by ``get_duration``.

    Only one generation runs at a time: a second concurrent call fails
    immediately. Inputs are validated, the pipeline is loaded on first use,
    the image is resized to a multiple of ``config.mod_value`` and the
    resulting frames are exported to a temporary ``.mp4`` file.

    Raises:
        gr.Error: on concurrent use, invalid input, model-loading failure,
            GPU out-of-memory, generation failure or save failure.
    """
    global pipe
    # Prevent concurrent execution
    if not generation_lock.acquire(blocking=False):
        raise gr.Error("โณ Another video is being generated. Please wait...")
    try:
        progress(0.05, desc="๐ Validating inputs...")
        logger.info(f"Starting generation - Resolution: {height}x{width}, Duration: {duration_seconds}s, Steps: {steps}")
        # Input validation
        is_valid, error_msg = video_generator.validate_inputs(
            input_image, prompt, height, width, duration_seconds, steps
        )
        if not is_valid:
            logger.warning(f"Validation failed: {error_msg}")
            raise gr.Error(error_msg)

        # Memory cleanup
        clear_gpu_memory()
        progress(0.1, desc="๐ Loading model...")
        # Model loading (inside the GPU function). The global is assigned only
        # after loading fully succeeds, so a failed attempt is retried next call
        # instead of leaving a half-configured pipeline behind.
        if pipe is None:
            try:
                pipe = _load_pipeline()
            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                raise gr.Error("Failed to load model") from e

        progress(0.3, desc="๐ฏ Preparing image...")
        # Image preparation: snap the requested size to mod_value multiples
        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
        # Frame count (very conservative): at most 24 frames, at least 8.
        # NOTE: the 24-frame cap means longer requested durations are truncated.
        num_frames = min(
            int(round(duration_seconds * config.fixed_fps)),
            24,
        )
        num_frames = max(8, num_frames)
        logger.info(f"Generating {num_frames} frames at {target_h}x{target_w}")
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        # Resize the image
        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

        progress(0.4, desc="๐ฌ Generating video...")
        # Video generation
        with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True, dtype=torch.float16):
            try:
                # Free cached blocks before the big allocation
                torch.cuda.empty_cache()
                output_frames_list = pipe(
                    image=resized_image,
                    prompt=prompt[:150],  # prompt length limit
                    negative_prompt=negative_prompt[:50] if negative_prompt else "",
                    height=target_h,
                    width=target_w,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(steps),
                    generator=torch.Generator(device="cuda").manual_seed(current_seed),
                    return_dict=True,
                    output_type="pil",
                ).frames[0]
                logger.info("Video generation completed successfully")
            except torch.cuda.OutOfMemoryError as e:
                logger.error("GPU OOM error")
                clear_gpu_memory()
                raise gr.Error("๐พ GPU out of memory. Try smaller dimensions (256x256 recommended).") from e
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    logger.error("Runtime OOM error")
                    clear_gpu_memory()
                    raise gr.Error("๐พ GPU memory error. Please try again with smaller settings.") from e
                logger.error(f"Runtime error: {e}")
                raise gr.Error(f"โ Generation failed: {str(e)[:50]}") from e
            except Exception as e:
                logger.error(f"Generation error: {type(e).__name__}: {e}")
                raise gr.Error("โ Generation failed. Try reducing resolution or steps.") from e

        progress(0.9, desc="๐พ Saving video...")
        # Save the video
        try:
            # Reserve a temp path; the file is kept (delete=False) so it can be returned.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
                video_path = tmpfile.name
            export_to_video(output_frames_list, video_path, fps=config.fixed_fps)
            logger.info(f"Video saved: {video_path}")
        except Exception as e:
            logger.error(f"Save error: {e}")
            raise gr.Error("Failed to save video") from e

        progress(1.0, desc="โจ Complete!")
        logger.info(f"Video generated: {num_frames} frames, {target_h}x{target_w}")
        # Memory cleanup
        del output_frames_list
        del resized_image
        torch.cuda.empty_cache()
        gc.collect()
        return video_path, current_seed
    except gr.Error:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {type(e).__name__}: {e}")
        raise gr.Error("โ Unexpected error. Please try again with smaller settings.") from e
    finally:
        generation_lock.release()
        clear_gpu_memory()
# CSS
# Stylesheet passed to gr.Blocks; the class names here are the ones the UI
# attaches via elem_classes and the inline HTML snippets.
css = """
.container {
max-width: 1000px;
margin: auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 30px;
border-radius: 15px;
color: white;
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.header h1 {
font-size: 2.5em;
margin-bottom: 10px;
}
.warning-box {
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 8px;
padding: 12px;
margin: 10px 0;
color: #856404;
font-size: 0.9em;
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.2em;
padding: 12px 30px;
border-radius: 25px;
border: none;
cursor: pointer;
width: 100%;
margin-top: 15px;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
"""
# Gradio UI
# Static markup is kept in named constants so the layout below stays readable.
_HEADER_HTML = """
<div class="header">
<h1>๐ฌ AI Video Generator</h1>
<p>Transform images into videos with Wan 2.1 (Zero GPU Optimized)</p>
</div>
"""

# Warning box
_LIMITS_HTML = """
<div class="warning-box">
<strong>โก Zero GPU Strict Limitations:</strong>
<ul style="margin: 5px 0; padding-left: 20px;">
<li>Max resolution: 320ร320 (recommended 256ร256)</li>
<li>Max duration: 1.2 seconds</li>
<li>Max steps: 5 (2-3 recommended)</li>
<li>Processing time: ~50-80 seconds</li>
<li>Please wait for completion before next generation</li>
</ul>
</div>
"""

_TIPS_MARKDOWN = """
### ๐ก Tips for Zero GPU:
- **Best**: 256ร256 resolution
- **Safe**: 2-3 steps only
- **Duration**: 0.8s is optimal
- **Prompts**: Keep short and simple
- **Important**: Wait for completion!
### โ ๏ธ If GPU stops:
- Reduce resolution to 256ร256
- Use only 2 steps
- Keep duration under 1 second
- Avoid extreme aspect ratios
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header and limitations notice
        gr.HTML(_HEADER_HTML)
        gr.HTML(_LIMITS_HTML)

        with gr.Row():
            # Left column: inputs and settings
            with gr.Column(scale=1):
                input_image = gr.Image(label="๐ผ๏ธ Upload Image", type="pil")
                prompt_input = gr.Textbox(
                    label="โจ Animation Prompt",
                    placeholder="Describe the motion...",
                    value=config.default_prompt,
                    lines=2,
                    max_lines=3,
                )
                duration_input = gr.Slider(
                    label="โฑ๏ธ Duration (seconds)",
                    minimum=0.3, maximum=1.2, step=0.1, value=0.8,
                )
                with gr.Accordion("โ๏ธ Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value=config.default_negative_prompt,
                        lines=1,
                    )
                    with gr.Row():
                        height_slider = gr.Slider(
                            label="Height",
                            minimum=128, maximum=512, step=32, value=256,
                        )
                        width_slider = gr.Slider(
                            label="Width",
                            minimum=128, maximum=512, step=32, value=256,
                        )
                    steps_slider = gr.Slider(
                        label="Steps (2-3 recommended)",
                        minimum=1, maximum=5, step=1, value=2,
                    )
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0, maximum=MAX_SEED, step=1, value=42,
                        )
                        randomize_seed = gr.Checkbox(label="Random", value=True)
                    # Hidden control; its value is still passed to generate_video.
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0, maximum=5.0, step=0.5, value=1.0,
                        visible=False,
                    )
                generate_btn = gr.Button(
                    "๐ฌ Generate Video",
                    elem_classes="generate-btn",
                    variant="primary",
                )

            # Right column: result and usage tips
            with gr.Column(scale=1):
                video_output = gr.Video(label="Generated Video", autoplay=True)
                gr.Markdown(_TIPS_MARKDOWN)

    # Event handlers
    input_image.upload(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider],
    )
    # Input order must match generate_video's positional parameters.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt_input,
            height_slider,
            width_slider,
            negative_prompt,
            duration_input,
            guidance_scale,
            steps_slider,
            seed,
            randomize_seed,
        ],
        outputs=[video_output, seed],
    )
if __name__ == "__main__":
    logger.info("Starting app in Zero GPU environment")
    # Small queue size, then start the app.
    demo.queue(max_size=2).launch()