import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
import time
import hashlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import wraps
import threading
import os

# GPU memory allocator settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
@dataclass
class VideoGenerationConfig:
    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    mod_value: int = 32
    default_height: int = 512
    default_width: int = 896
    max_area: float = 480.0 * 832.0
    slider_min_h: int = 128
    slider_max_h: int = 896
    slider_min_w: int = 128
    slider_max_w: int = 896
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 81
    default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
    default_negative_prompt: str = "static, blurred, low quality, watermark, text"
    # GPU memory optimization flags
    enable_model_cpu_offload: bool = True
    enable_vae_slicing: bool = True
    enable_vae_tiling: bool = True

config = VideoGenerationConfig()
MAX_SEED = np.iinfo(np.int32).max

# Global lock (prevents concurrent generations)
generation_lock = threading.Lock()

# Timing decorator
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
        return result
    return wrapper

# GPU memory cleanup helper
def clear_gpu_memory():
    """Aggressively release GPU memory and log the current usage."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        # Log current GPU memory state
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        logger.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

# Model manager (singleton pattern)
class ModelManager:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, '_initialized'):
            self._pipe = None
            self._is_loaded = False
            self._initialized = True

    @property
    def pipe(self):
        if not self._is_loaded:
            self._load_model()
        return self._pipe

    @measure_time
    def _load_model(self):
        """Load the pipeline with memory optimizations."""
        with self._lock:
            if self._is_loaded:
                return
            try:
                logger.info("Loading model with memory optimizations...")
                clear_gpu_memory()

                # Load model components in half precision to reduce memory
                image_encoder = CLIPVisionModel.from_pretrained(
                    config.model_id,
                    subfolder="image_encoder",
                    torch_dtype=torch.float16,  # float16 instead of float32
                    low_cpu_mem_usage=True
                )
                vae = AutoencoderKLWan.from_pretrained(
                    config.model_id,
                    subfolder="vae",
                    torch_dtype=torch.float16,  # float16 instead of float32
                    low_cpu_mem_usage=True
                )
                self._pipe = WanImageToVideoPipeline.from_pretrained(
                    config.model_id,
                    vae=vae,
                    image_encoder=image_encoder,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    use_safetensors=True
                )

                # Scheduler setup
                self._pipe.scheduler = UniPCMultistepScheduler.from_config(
                    self._pipe.scheduler.config, flow_shift=8.0
                )

                # Download the CausVid LoRA weights
                causvid_path = hf_hub_download(
                    repo_id=config.lora_repo_id,
                    filename=config.lora_filename
                )
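                # Note: CausVid is a distillation LoRA for Wan 2.1 that enables
                # few-step sampling, which is why this app defaults to very low
                # step counts (4) with guidance_scale=1 further below.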
                self._pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
                self._pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
                self._pipe.fuse_lora()

                # GPU optimization settings
                if config.enable_model_cpu_offload:
                    self._pipe.enable_model_cpu_offload()
                else:
                    self._pipe.to("cuda")
                if config.enable_vae_slicing:
                    self._pipe.enable_vae_slicing()
                if config.enable_vae_tiling:
                    self._pipe.enable_vae_tiling()

                # Enable xFormers memory-efficient attention when available
                try:
                    self._pipe.enable_xformers_memory_efficient_attention()
                    logger.info("xFormers memory efficient attention enabled")
                except Exception:
                    logger.info("xFormers not available, using default attention")

                self._is_loaded = True
                logger.info("Model loaded successfully with optimizations")
                clear_gpu_memory()

            except Exception as e:
                logger.error(f"Error loading model: {e}")
                self._is_loaded = False
                clear_gpu_memory()
                raise

    def unload_model(self):
        """Unload the model and release memory."""
        with self._lock:
            if self._pipe is not None:
                del self._pipe
                self._pipe = None
            self._is_loaded = False
            clear_gpu_memory()
            logger.info("Model unloaded and memory cleared")

# Singleton instance
model_manager = ModelManager()

# Video generator
class VideoGenerator:
    def __init__(self, config: VideoGenerationConfig, model_manager: ModelManager):
        self.config = config
        self.model_manager = model_manager

    def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
        """Choose an output size that preserves the aspect ratio, stays within
        max_area, and is a multiple of mod_value on both sides."""
        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return self.config.default_height, self.config.default_width

        aspect_ratio = orig_h / orig_w
        calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
        calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))

        calc_h = max(self.config.mod_value, (calc_h // self.config.mod_value) * self.config.mod_value)
        calc_w = max(self.config.mod_value, (calc_w // self.config.mod_value) * self.config.mod_value)

        new_h = int(np.clip(calc_h, self.config.slider_min_h,
                            (self.config.slider_max_h // self.config.mod_value) * self.config.mod_value))
        new_w = int(np.clip(calc_w, self.config.slider_min_w,
                            (self.config.slider_max_w // self.config.mod_value) * self.config.mod_value))
        return new_h, new_w

    def validate_inputs(self, image: Image.Image, prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        if image is None:
            return False, "🖼️ Please upload an input image"
        if not prompt or len(prompt.strip()) == 0:
            return False, "✍️ Please provide a prompt"
        if len(prompt) > 500:
            return False, "⚠️ Prompt is too long (max 500 characters)"
        if duration < self.config.min_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too short (min {self.config.min_frames/self.config.fixed_fps:.1f}s)"
        if duration > self.config.max_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too long (max {self.config.max_frames/self.config.fixed_fps:.1f}s)"

        # GPU memory check (very rough estimate)
        if torch.cuda.is_available():
            free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
            required_memory = (height * width * 3 * 8 * duration * self.config.fixed_fps) / (1024**3)
            if free_memory < required_memory * 2:  # require a 2x safety margin
                clear_gpu_memory()
                return False, "⚠️ Not enough GPU memory. Try smaller dimensions or shorter duration."
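        # Worked example of the estimate above: the default 512x896 clip at
        # 2 s / 24 fps gives (512 * 896 * 3 * 8 * 48) / 1024**3 ≈ 0.49 GB, so
        # with the 2x margin roughly 1 GB of free VRAM is needed to pass.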
        return True, None

    def generate_unique_filename(self, seed: int) -> str:
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"

video_generator = VideoGenerator(config, model_manager)

# Gradio callbacks
def handle_image_upload(image):
    if image is None:
        return gr.update(value=config.default_height), gr.update(value=config.default_width)
    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
        new_h, new_w = video_generator.calculate_dimensions(image)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        gr.Warning("⚠️ Error processing image")
        return gr.update(value=config.default_height), gr.update(value=config.default_width)

def get_duration(input_image, prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress):
    """Estimate the GPU time budget (in seconds) for @spaces.GPU based on the request."""
    base_duration = 60
    if steps > 4:
        base_duration += 15
    if duration_seconds > 2:
        base_duration += 15
    # Extra time for higher resolutions
    pixels = height * width
    if pixels > 500000:
        base_duration += 20
    return min(base_duration, 120)  # cap at 120 seconds
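# Illustrative budgets from get_duration: the defaults (4 steps, 2 s, 512x896,
# ~459k px) stay at the 60 s base; 6 steps at 3 s and 832x896 (~745k px) give
# 60 + 15 + 15 + 20 = 110 s; larger requests are capped at 120 s.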
@spaces.GPU(duration=get_duration)
@measure_time
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=2, guidance_scale=1, steps=4, seed=42,
                   randomize_seed=False, progress=gr.Progress(track_tqdm=True)):
    # Prevent concurrent runs
    if not generation_lock.acquire(blocking=False):
        raise gr.Error("⏳ Another video is being generated. Please wait...")

    try:
        progress(0.1, desc="🔍 Validating inputs...")
        is_valid, error_msg = video_generator.validate_inputs(
            input_image, prompt, height, width, duration_seconds, steps
        )
        if not is_valid:
            raise gr.Error(error_msg)

        # Free memory before starting
        clear_gpu_memory()

        progress(0.2, desc="🎯 Preparing image...")
        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
        num_frames = np.clip(int(round(duration_seconds * config.fixed_fps)),
                             config.min_frames, config.max_frames)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        # Resize the input image to the target resolution
        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

        progress(0.3, desc="🎨 Loading model...")
        pipe = model_manager.pipe

        progress(0.4, desc="🎬 Generating video frames...")
        # Memory-efficient generation
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True):
            try:
                output_frames_list = pipe(
                    image=resized_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    height=target_h,
                    width=target_w,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(steps),
                    generator=torch.Generator(device="cuda").manual_seed(current_seed),
                    return_dict=True
                ).frames[0]
            except torch.cuda.OutOfMemoryError:
                # CUDA OOM gets its own handler so the user sees an actionable message
                clear_gpu_memory()
                raise gr.Error("💾 GPU out of memory. Try smaller dimensions or shorter duration.")
            except Exception as e:
                logger.error(f"Generation error: {e}")
                raise gr.Error(f"❌ Generation failed: {str(e)}")

        progress(0.9, desc="💾 Saving video...")
        filename = video_generator.generate_unique_filename(current_seed)  # currently unused; export writes to the temp file below
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)

        progress(1.0, desc="✨ Complete!")
        return video_path, current_seed

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise
    finally:
        # Always release the lock, then drop large locals and clear GPU memory
        generation_lock.release()
        if 'output_frames_list' in locals():
            del output_frames_list
        if 'resized_image' in locals():
            del resized_image
        clear_gpu_memory()
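# Memory pattern used in generate_video: torch.inference_mode() disables
# autograd bookkeeping and torch.cuda.amp.autocast() runs eligible ops in
# reduced precision; together they lower peak VRAM during the denoising loop.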
# CSS styles
css = """
.container { max-width: 1200px; margin: auto; padding: 20px; }
.header {
    text-align: center; margin-bottom: 30px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 40px; border-radius: 20px; color: white;
    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    position: relative; overflow: hidden;
}
.header::before {
    content: ''; position: absolute; top: -50%; left: -50%;
    width: 200%; height: 200%;
    background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
    animation: pulse 4s ease-in-out infinite;
}
@keyframes pulse {
    0%, 100% { transform: scale(1); opacity: 0.5; }
    50% { transform: scale(1.1); opacity: 0.8; }
}
.header h1 {
    font-size: 3em; margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    position: relative; z-index: 1;
}
.header p { font-size: 1.2em; opacity: 0.95; position: relative; z-index: 1; }
.gpu-status {
    position: absolute; top: 10px; right: 10px;
    background: rgba(0,0,0,0.3); padding: 5px 15px;
    border-radius: 20px; font-size: 0.8em;
}
.main-content {
    background: rgba(255, 255, 255, 0.95); border-radius: 20px;
    padding: 30px; box-shadow: 0 5px 20px rgba(0,0,0,0.1);
    backdrop-filter: blur(10px);
}
.input-section {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    padding: 25px; border-radius: 15px; margin-bottom: 20px;
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white; font-size: 1.3em; padding: 15px 40px;
    border-radius: 30px; border: none; cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
    width: 100%; margin-top: 20px;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.generate-btn:active { transform: translateY(0); }
.video-output {
    background: #f8f9fa; padding: 20px; border-radius: 15px;
    text-align: center; min-height: 400px;
    display: flex; align-items: center; justify-content: center;
}
.accordion {
    background: rgba(255, 255, 255, 0.7); border-radius: 10px;
    margin-top: 15px; padding: 15px;
}
.slider-container {
    background: rgba(255, 255, 255, 0.5); padding: 15px;
    border-radius: 10px; margin: 10px 0;
}
body {
    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
}
@keyframes gradient {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
.warning-box {
    background: rgba(255, 193, 7, 0.1);
    border: 1px solid rgba(255, 193, 7, 0.3);
    border-radius: 10px; padding: 15px; margin: 10px 0;
    color: #856404; font-size: 0.9em;
}
.footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
"""

# Gradio UI
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header with GPU status
        gr.HTML("""
            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
            <p>Enhanced with: 🛡️ GPU Crash Protection • ⚡ Memory Optimization • 🎨 Modern UI • 🔧 Clean Architecture</p>