import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Parameter configuration
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 512
DEFAULT_W_SLIDER_VALUE = 512  # square default for Zero GPU
NEW_FORMULA_MAX_AREA = 480.0 * 832.0

SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, blurred, low quality, watermark, text"

# Load model components globally, once at startup
logger.info("Loading model components...")
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# Load the CausVid LoRA (enables few-step inference)
try:
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
    pipe.fuse_lora()
    logger.info("LoRA loaded successfully")
except Exception as e:
    logger.warning(f"LoRA loading failed: {e}")

# Memory optimization -- use only methods supported by WanImageToVideoPipeline
try:
    pipe.enable_model_cpu_offload()
    logger.info("CPU offload enabled")
except Exception:
    logger.info("CPU offload not available")

logger.info("Model loaded and ready")


def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                  min_slider_h, max_slider_h,
                                  min_slider_w, max_slider_w,
                                  default_h, default_w):
    orig_w, orig_h = pil_image.size
    if orig_w <= 0 or orig_h <= 0:
        return default_h, default_w

    aspect_ratio = orig_h / orig_w

    # Conservative sizing for Zero GPU: use a smaller area budget
    # (hasattr(spaces, 'GPU') serves as a heuristic for the Zero GPU environment)
    if hasattr(spaces, 'GPU'):
        calculation_max_area = min(calculation_max_area, 320.0 * 320.0)

    # Largest H x W at the original aspect ratio whose area fits the budget,
    # snapped down to multiples of mod_val
    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)

    # Additional caps in the Zero GPU environment
    if hasattr(spaces, 'GPU'):
        max_slider_h = min(max_slider_h, 640)
        max_slider_w = min(max_slider_w, 640)

    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
    return new_h, new_w
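
# Worked example for the sizing above (illustrative numbers only, nothing executed):
# a 1024x768 upload on Zero GPU gets an area budget of 320 * 320 = 102,400 px and
# aspect_ratio = 768 / 1024 = 0.75, so
#   calc_h = round(sqrt(102400 * 0.75)) = 277 -> (277 // 32) * 32 = 256
#   calc_w = round(sqrt(102400 / 0.75)) = 370 -> (370 // 32) * 32 = 352
# which keeps 256 * 352 = 90,112 px inside the budget at roughly the original
# aspect ratio.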
def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
    if uploaded_pil_image is None:
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
    try:
        new_h, new_w = _calculate_new_dimensions_wan(
            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
        )
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception:
        gr.Warning("Error attempting to calculate new dimensions")
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)


def get_duration(input_image, prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress):
    # Conservative GPU-time allocation; the Zero GPU environment gets slightly
    # larger budgets than a regular GPU
    if hasattr(spaces, 'GPU'):
        if steps > 4 and duration_seconds > 2:
            return 90
        elif steps > 4 or duration_seconds > 2:
            return 80
        else:
            return 70
    else:
        # Regular GPU environment
        if steps > 4 and duration_seconds > 2:
            return 90
        elif steps > 4 or duration_seconds > 2:
            return 75
        else:
            return 60
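
# Note: `spaces.GPU` also accepts a callable for `duration` (ZeroGPU's dynamic
# duration feature); it is invoked with the same arguments as the decorated
# function, so the allocation is sized per request. With the Zero GPU branch above:
#   steps=4, duration_seconds=2.0 -> 70 s  (neither threshold exceeded)
#   steps=6, duration_seconds=2.0 -> 80 s  (only steps > 4)
#   steps=6, duration_seconds=2.5 -> 90 s  (both thresholds exceeded)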
@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt,
                   duration_seconds=2, guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Extra validation in the Zero GPU environment
    if hasattr(spaces, 'GPU'):
        # Pixel cap
        max_pixels = 409600  # 640 x 640
        if height * width > max_pixels:
            raise gr.Error(f"Resolution too high for Zero GPU. Maximum {max_pixels:,} pixels (e.g., 640×640)")
        # Duration cap
        if duration_seconds > 2.5:
            duration_seconds = 2.5
            gr.Warning("Duration limited to 2.5s in Zero GPU environment")
        # Steps cap
        if steps > 8:
            steps = 8
            gr.Warning("Steps limited to 8 in Zero GPU environment")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    # Additional frame-count cap on Zero GPU
    if hasattr(spaces, 'GPU'):
        max_frames_zerogpu = int(2.5 * FIXED_FPS)  # 2.5 seconds
        num_frames = min(num_frames, max_frames_zerogpu)

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    logger.info(f"Generating video: {target_h}x{target_w}, {num_frames} frames, seed={current_seed}")

    # Resize the input image to the target resolution
    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

    try:
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]
    except torch.cuda.OutOfMemoryError:
        gc.collect()
        torch.cuda.empty_cache()
        raise gr.Error("GPU out of memory. Try smaller resolution or shorter duration.")
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise gr.Error(f"Video generation failed: {str(e)[:100]}")

    # Save the video
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    # Free memory
    del output_frames_list
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return video_path, current_seed
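
# Illustrative direct call (an assumed usage sketch, not part of the app; the UI
# below presumably wires generate_video to a button instead). Argument names mirror
# the signature above; "example.jpg" is a placeholder path.
#
#   from PIL import Image
#   img = Image.open("example.jpg").convert("RGB")
#   video_path, used_seed = generate_video(
#       img, default_prompt_i2v, 512, 512,
#       negative_prompt=default_negative_prompt,
#       duration_seconds=2, guidance_scale=1, steps=4,
#       seed=42, randomize_seed=False,
#   )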

# CSS styles (keeps the existing UI)
css = """
.container { max-width: 1200px; margin: auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 40px; border-radius: 20px; color: white; box-shadow: 0 10px 30px rgba(0,0,0,0.2); position: relative; overflow: hidden; }
.header::before { content: ''; position: absolute; top: -50%; left: -50%; width: 200%; height: 200%; background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%); animation: pulse 4s ease-in-out infinite; }
@keyframes pulse { 0%, 100% { transform: scale(1); opacity: 0.5; } 50% { transform: scale(1.1); opacity: 0.8; } }
.header h1 { font-size: 3em; margin-bottom: 10px; text-shadow: 2px 2px 4px rgba(0,0,0,0.3); position: relative; z-index: 1; }
.header p { font-size: 1.2em; opacity: 0.95; position: relative; z-index: 1; }
.gpu-status { position: absolute; top: 10px; right: 10px; background: rgba(0,0,0,0.3); padding: 5px 15px; border-radius: 20px; font-size: 0.8em; }
.main-content { background: rgba(255, 255, 255, 0.95); border-radius: 20px; padding: 30px; box-shadow: 0 5px 20px rgba(0,0,0,0.1); backdrop-filter: blur(10px); }
.input-section { background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); padding: 25px; border-radius: 15px; margin-bottom: 20px; }
.generate-btn { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; font-size: 1.3em; padding: 15px 40px; border-radius: 30px; border: none; cursor: pointer; transition: all 0.3s ease; box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4); width: 100%; margin-top: 20px; }
.generate-btn:hover { transform: translateY(-2px); box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6); }
.generate-btn:active { transform: translateY(0); }
.video-output { background: #f8f9fa; padding: 20px; border-radius: 15px; text-align: center; min-height: 400px; display: flex; align-items: center; justify-content: center; }
.accordion { background: rgba(255, 255, 255, 0.7); border-radius: 10px; margin-top: 15px; padding: 15px; }
.slider-container { background: rgba(255, 255, 255, 0.5); padding: 15px; border-radius: 10px; margin: 10px 0; }
body { background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab); background-size: 400% 400%; animation: gradient 15s ease infinite; }
@keyframes gradient { 0% { background-position: 0% 50%; } 50% { background-position: 100% 50%; } 100% { background-position: 0% 50%; } }
.warning-box { background: rgba(255, 193, 7, 0.1); border: 1px solid rgba(255, 193, 7, 0.3); border-radius: 10px; padding: 15px; margin: 10px 0; color: #856404; font-size: 0.9em; }
.info-box { background: rgba(52, 152, 219, 0.1); border: 1px solid rgba(52, 152, 219, 0.3); border-radius: 10px; padding: 15px; margin: 10px 0; color: #2c5282; font-size: 0.9em; }
.footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
"""

# Gradio UI (keeps the existing structure)
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header with GPU status
        gr.HTML("""
            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
            <p>Powered by: Wan 2.1 I2V (14B) + CausVid LoRA • 🚀 4-8 steps fast inference • 🎬 Up to 81 frames</p>