import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
import time
import hashlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import wraps

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration management
@dataclass
class VideoGenerationConfig:
    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    mod_value: int = 32
    default_height: int = 512
    default_width: int = 896
    max_area: float = 480.0 * 832.0
    slider_min_h: int = 128
    slider_max_h: int = 896
    slider_min_w: int = 128
    slider_max_w: int = 896
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 81
    default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
    default_negative_prompt: str = "static, blurred, low quality, watermark, text"

config = VideoGenerationConfig()
MAX_SEED = np.iinfo(np.int32).max

# Performance measurement decorator
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
        return result
    return wrapper

# Model manager
class ModelManager:
    def __init__(self):
        self._pipe = None
        self._is_loaded = False

    @property
    def pipe(self):
        if not self._is_loaded:
            self._load_model()
        return self._pipe

    @measure_time
    def _load_model(self):
        logger.info("Loading model...")
        image_encoder = CLIPVisionModel.from_pretrained(
            config.model_id, subfolder="image_encoder", torch_dtype=torch.float32
        )
        vae = AutoencoderKLWan.from_pretrained(
            config.model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self._pipe = WanImageToVideoPipeline.from_pretrained(
            config.model_id, vae=vae, image_encoder=image_encoder,
            torch_dtype=torch.bfloat16
        )
        self._pipe.scheduler = UniPCMultistepScheduler.from_config(
            self._pipe.scheduler.config, flow_shift=8.0
        )
        self._pipe.to("cuda")

        causvid_path = hf_hub_download(
            repo_id=config.lora_repo_id, filename=config.lora_filename
        )
        self._pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
        self._pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
        self._pipe.fuse_lora()

        self._is_loaded = True
        logger.info("Model loaded successfully")

model_manager = ModelManager()
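# Illustrative usage note (added by the editor, not part of the original flow):
# the pipeline is loaded lazily, so the first access to `model_manager.pipe`
# pays the full download/load cost and later accesses reuse the cached pipeline:
#
#   pipe = model_manager.pipe   # first call: triggers _load_model() on the GPU
#   pipe = model_manager.pipe   # later calls: return the already-loaded pipeline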
# Video generator class
class VideoGenerator:
    def __init__(self, config: VideoGenerationConfig, model_manager: ModelManager):
        self.config = config
        self.model_manager = model_manager

    def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return self.config.default_height, self.config.default_width

        aspect_ratio = orig_h / orig_w
        calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
        calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))
        calc_h = max(self.config.mod_value, (calc_h // self.config.mod_value) * self.config.mod_value)
        calc_w = max(self.config.mod_value, (calc_w // self.config.mod_value) * self.config.mod_value)

        new_h = int(np.clip(calc_h, self.config.slider_min_h,
                            (self.config.slider_max_h // self.config.mod_value) * self.config.mod_value))
        new_w = int(np.clip(calc_w, self.config.slider_min_w,
                            (self.config.slider_max_w // self.config.mod_value) * self.config.mod_value))
        return new_h, new_w

    def validate_inputs(self, image: Image.Image, prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        if image is None:
            return False, "🖼️ Please upload an input image"
        if not prompt or len(prompt.strip()) == 0:
            return False, "✍️ Please provide a prompt"
        if len(prompt) > 500:
            return False, "⚠️ Prompt is too long (max 500 characters)"
        if duration < self.config.min_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too short (min {self.config.min_frames/self.config.fixed_fps:.1f}s)"
        if duration > self.config.max_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too long (max {self.config.max_frames/self.config.fixed_fps:.1f}s)"
        return True, None

    def generate_unique_filename(self, seed: int) -> str:
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"

video_generator = VideoGenerator(config, model_manager)
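# Worked example (editor's illustration, using the default config above):
# a 1920x1080 input gives aspect_ratio = 1080/1920 = 0.5625, so
#   calc_h = round(sqrt(480*832 * 0.5625)) = 474 -> floored to 448 (multiple of 32)
#   calc_w = round(sqrt(480*832 / 0.5625)) = 843 -> floored to 832 (multiple of 32)
# i.e. calculate_dimensions() would animate the image at 832x448, preserving the
# aspect ratio while staying within the 480*832 pixel-area budget.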
# Gradio functions
def handle_image_upload(image):
    if image is None:
        return gr.update(value=config.default_height), gr.update(value=config.default_width)
    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
        new_h, new_w = video_generator.calculate_dimensions(image)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        gr.Warning("⚠️ Error processing image")
        return gr.update(value=config.default_height), gr.update(value=config.default_width)

def get_duration(input_image, prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress):
    # GPU time budget (seconds) requested via @spaces.GPU, scaled by workload
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60

@spaces.GPU(duration=get_duration)
@measure_time
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=2, guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    progress(0.1, desc="🔍 Validating inputs...")

    # Input validation
    is_valid, error_msg = video_generator.validate_inputs(
        input_image, prompt, height, width, duration_seconds, steps
    )
    if not is_valid:
        raise gr.Error(error_msg)

    try:
        progress(0.2, desc="🎯 Preparing image...")
        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)

        num_frames = np.clip(int(round(duration_seconds * config.fixed_fps)),
                             config.min_frames, config.max_frames)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

        progress(0.3, desc="🎨 Loading model...")
        pipe = model_manager.pipe

        progress(0.4, desc="🎬 Generating video frames...")
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]

        progress(0.9, desc="💾 Saving video...")
        # Note: the generated filename is currently unused; the temp file name is returned.
        filename = video_generator.generate_unique_filename(current_seed)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)

        progress(1.0, desc="✨ Complete!")
        return video_path, current_seed

    finally:
        # Memory cleanup
        if 'output_frames_list' in locals():
            del output_frames_list
        gc.collect()
        torch.cuda.empty_cache()

# CSS styles
css = """
.container { max-width: 1200px; margin: auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 40px; border-radius: 20px; color: white; box-shadow: 0 10px 30px rgba(0,0,0,0.2); }
.header h1 { font-size: 3em; margin-bottom: 10px; text-shadow: 2px 2px 4px rgba(0,0,0,0.3); }
.header p { font-size: 1.2em; opacity: 0.95; }
.main-content { background: rgba(255, 255, 255, 0.95); border-radius: 20px; padding: 30px; box-shadow: 0 5px 20px rgba(0,0,0,0.1); backdrop-filter: blur(10px); }
.input-section { background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); padding: 25px; border-radius: 15px; margin-bottom: 20px; }
.generate-btn { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; font-size: 1.3em; padding: 15px 40px; border-radius: 30px; border: none; cursor: pointer; transition: all 0.3s ease; box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4); width: 100%; margin-top: 20px; }
.generate-btn:hover { transform: translateY(-2px); box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6); }
.video-output { background: #f8f9fa; padding: 20px; border-radius: 15px; text-align: center; min-height: 400px; display: flex; align-items: center; justify-content: center; }
.accordion { background: rgba(255, 255, 255, 0.7); border-radius: 10px; margin-top: 15px; padding: 15px; }
.slider-container { background: rgba(255, 255, 255, 0.5); padding: 15px; border-radius: 10px; margin: 10px 0; }
body { background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab); background-size: 400% 400%; animation: gradient 15s ease infinite; }
@keyframes gradient { 0% { background-position: 0% 50%; } 50% { background-position: 100% 50%; } 100% { background-position: 0% 50%; } }
.gr-button-secondary { background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); }
.footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
"""

# Gradio UI
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header
        gr.HTML("""
Transform your images into captivating videos with Wan 2.1 + CausVid LoRA