seawolf2357 committed
Commit ec4cebf · verified · 1 Parent(s): 5c774ff

Update app.py

Files changed (1)
  1. app.py +448 -130
app.py CHANGED
@@ -9,81 +9,162 @@ from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
 import random
+import logging
+import gc
+import time
+import hashlib
+from dataclasses import dataclass
+from typing import Optional, Tuple
+from functools import wraps
 
-MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-LORA_REPO_ID = "Kijai/WanVideo_comfy"
-LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-
-image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
-vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanImageToVideoPipeline.from_pretrained(
-    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
-)
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
-pipe.to("cuda")
-
-causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
-pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
-pipe.fuse_lora()
-
-MOD_VALUE = 32
-DEFAULT_H_SLIDER_VALUE = 512
-DEFAULT_W_SLIDER_VALUE = 896
-NEW_FORMULA_MAX_AREA = 480.0 * 832.0
-
-SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
-SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
-MAX_SEED = np.iinfo(np.int32).max
-
-FIXED_FPS = 24
-MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL = 81
-
-default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
-default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
-
-def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                  min_slider_h, max_slider_h,
-                                  min_slider_w, max_slider_w,
-                                  default_h, default_w):
-    orig_w, orig_h = pil_image.size
-    if orig_w <= 0 or orig_h <= 0:
-        return default_h, default_w
-
-    aspect_ratio = orig_h / orig_w
-
-    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-
-    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
-    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
-    return new_h, new_w
-
-def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
-    if uploaded_pil_image is None:
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Configuration management
+@dataclass
+class VideoGenerationConfig:
+    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+    lora_repo_id: str = "Kijai/WanVideo_comfy"
+    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+    mod_value: int = 32
+    default_height: int = 512
+    default_width: int = 896
+    max_area: float = 480.0 * 832.0
+    slider_min_h: int = 128
+    slider_max_h: int = 896
+    slider_min_w: int = 128
+    slider_max_w: int = 896
+    fixed_fps: int = 24
+    min_frames: int = 8
+    max_frames: int = 81
+    default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
+    default_negative_prompt: str = "static, blurred, low quality, watermark, text"
+
+config = VideoGenerationConfig()
+MAX_SEED = np.iinfo(np.int32).max
+
+# Performance measurement decorator
+def measure_time(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        start = time.time()
+        result = func(*args, **kwargs)
+        logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
+        return result
+    return wrapper
+
+# Model manager
+class ModelManager:
+    def __init__(self):
+        self._pipe = None
+        self._is_loaded = False
+
+    @property
+    def pipe(self):
+        if not self._is_loaded:
+            self._load_model()
+        return self._pipe
+
+    @measure_time
+    def _load_model(self):
+        logger.info("Loading model...")
+        image_encoder = CLIPVisionModel.from_pretrained(
+            config.model_id, subfolder="image_encoder", torch_dtype=torch.float32
+        )
+        vae = AutoencoderKLWan.from_pretrained(
+            config.model_id, subfolder="vae", torch_dtype=torch.float32
+        )
+        self._pipe = WanImageToVideoPipeline.from_pretrained(
+            config.model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+        )
+        self._pipe.scheduler = UniPCMultistepScheduler.from_config(
+            self._pipe.scheduler.config, flow_shift=8.0
+        )
+        self._pipe.to("cuda")
+
+        causvid_path = hf_hub_download(
+            repo_id=config.lora_repo_id, filename=config.lora_filename
+        )
+        self._pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+        self._pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
+        self._pipe.fuse_lora()
+        self._is_loaded = True
+        logger.info("Model loaded successfully")
+
+model_manager = ModelManager()
+
+# Video generator class
+class VideoGenerator:
+    def __init__(self, config: VideoGenerationConfig, model_manager: ModelManager):
+        self.config = config
+        self.model_manager = model_manager
+
+    def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
+        orig_w, orig_h = image.size
+        if orig_w <= 0 or orig_h <= 0:
+            return self.config.default_height, self.config.default_width
+
+        aspect_ratio = orig_h / orig_w
+        calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
+        calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))
+
+        calc_h = max(self.config.mod_value, (calc_h // self.config.mod_value) * self.config.mod_value)
+        calc_w = max(self.config.mod_value, (calc_w // self.config.mod_value) * self.config.mod_value)
+
+        new_h = int(np.clip(calc_h, self.config.slider_min_h,
+                            (self.config.slider_max_h // self.config.mod_value) * self.config.mod_value))
+        new_w = int(np.clip(calc_w, self.config.slider_min_w,
+                            (self.config.slider_max_w // self.config.mod_value) * self.config.mod_value))
+
+        return new_h, new_w
+
+    def validate_inputs(self, image: Image.Image, prompt: str, height: int,
+                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
+        if image is None:
+            return False, "🖼️ Please upload an input image"
+
+        if not prompt or len(prompt.strip()) == 0:
+            return False, "✍️ Please provide a prompt"
+
+        if len(prompt) > 500:
+            return False, "⚠️ Prompt is too long (max 500 characters)"
+
+        if duration < self.config.min_frames / self.config.fixed_fps:
+            return False, f"⏱️ Duration too short (min {self.config.min_frames/self.config.fixed_fps:.1f}s)"
+
+        if duration > self.config.max_frames / self.config.fixed_fps:
+            return False, f"⏱️ Duration too long (max {self.config.max_frames/self.config.fixed_fps:.1f}s)"
+
+        return True, None
+
+    def generate_unique_filename(self, seed: int) -> str:
+        timestamp = int(time.time())
+        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
+        hash_obj = hashlib.md5(unique_str.encode())
+        return f"video_{hash_obj.hexdigest()[:8]}.mp4"
+
+video_generator = VideoGenerator(config, model_manager)
+
+# Gradio functions
+def handle_image_upload(image):
+    if image is None:
+        return gr.update(value=config.default_height), gr.update(value=config.default_width)
+
     try:
-        new_h, new_w = _calculate_new_dimensions_wan(
-            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-        )
+        if not isinstance(image, Image.Image):
+            raise ValueError("Invalid image format")
+
+        new_h, new_w = video_generator.calculate_dimensions(image)
         return gr.update(value=new_h), gr.update(value=new_w)
+
     except Exception as e:
-        gr.Warning("Error attempting to calculate new dimensions")
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
-def get_duration(input_image, prompt, height, width,
-                 negative_prompt, duration_seconds,
-                 guidance_scale, steps,
-                 seed, randomize_seed,
-                 progress):
+        logger.error(f"Error processing image: {e}")
+        gr.Warning("⚠️ Error processing image")
+        return gr.update(value=config.default_height), gr.update(value=config.default_width)
+
+def get_duration(input_image, prompt, height, width, negative_prompt,
+                 duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
     if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
@@ -92,85 +173,322 @@ def get_duration(input_image, prompt, height, width,
         return 60
 
 @spaces.GPU(duration=get_duration)
+@measure_time
 def generate_video(input_image, prompt, height, width,
-                   negative_prompt=default_negative_prompt, duration_seconds = 2,
-                   guidance_scale = 1, steps = 4,
-                   seed = 42, randomize_seed = False,
+                   negative_prompt=config.default_negative_prompt,
+                   duration_seconds=2, guidance_scale=1, steps=4,
+                   seed=42, randomize_seed=False,
                    progress=gr.Progress(track_tqdm=True)):
 
-    if input_image is None:
-        raise gr.Error("Please upload an input image.")
-
-    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
-    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+    progress(0.1, desc="🔍 Validating inputs...")
 
-    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+    # Input validation
+    is_valid, error_msg = video_generator.validate_inputs(
+        input_image, prompt, height, width, duration_seconds, steps
+    )
+    if not is_valid:
+        raise gr.Error(error_msg)
 
-    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
-
-    resized_image = input_image.resize((target_w, target_h))
-
-    with torch.inference_mode():
-        output_frames_list = pipe(
-            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
-            height=target_h, width=target_w, num_frames=num_frames,
-            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
-            generator=torch.Generator(device="cuda").manual_seed(current_seed)
-        ).frames[0]
-
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-        video_path = tmpfile.name
-    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
-    return video_path, current_seed
-
-with gr.Blocks() as demo:
-    gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA")
-    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 that runs in just 4-8 steps, [extracted as a LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors), and is compatible with 🧨 diffusers")
-    with gr.Row():
-        with gr.Column():
-            input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
-            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-            duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+    try:
+        progress(0.2, desc="🎯 Preparing image...")
+        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
+        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
+        num_frames = np.clip(int(round(duration_seconds * config.fixed_fps)),
+                             config.min_frames, config.max_frames)
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
+
+        progress(0.3, desc="🎨 Loading model...")
+        pipe = model_manager.pipe
+
+        progress(0.4, desc="🎬 Generating video frames...")
+        with torch.inference_mode():
+            output_frames_list = pipe(
+                image=resized_image,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                height=target_h,
+                width=target_w,
+                num_frames=num_frames,
+                guidance_scale=float(guidance_scale),
+                num_inference_steps=int(steps),
+                generator=torch.Generator(device="cuda").manual_seed(current_seed)
+            ).frames[0]
+
+        progress(0.9, desc="💾 Saving video...")
+        filename = video_generator.generate_unique_filename(current_seed)
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+            video_path = tmpfile.name
+
+        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)
+
+        progress(1.0, desc="✨ Complete!")
+        return video_path, current_seed
+
+    finally:
+        # Memory cleanup
+        if 'output_frames_list' in locals():
+            del output_frames_list
+        gc.collect()
+        torch.cuda.empty_cache()
+
+# CSS styles
+css = """
+.container {
+    max-width: 1200px;
+    margin: auto;
+    padding: 20px;
+}
+
+.header {
+    text-align: center;
+    margin-bottom: 30px;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    padding: 40px;
+    border-radius: 20px;
+    color: white;
+    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
+}
+
+.header h1 {
+    font-size: 3em;
+    margin-bottom: 10px;
+    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
+}
+
+.header p {
+    font-size: 1.2em;
+    opacity: 0.95;
+}
+
+.main-content {
+    background: rgba(255, 255, 255, 0.95);
+    border-radius: 20px;
+    padding: 30px;
+    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
+    backdrop-filter: blur(10px);
+}
+
+.input-section {
+    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
+    padding: 25px;
+    border-radius: 15px;
+    margin-bottom: 20px;
+}
+
+.generate-btn {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    color: white;
+    font-size: 1.3em;
+    padding: 15px 40px;
+    border-radius: 30px;
+    border: none;
+    cursor: pointer;
+    transition: all 0.3s ease;
+    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
+    width: 100%;
+    margin-top: 20px;
+}
+
+.generate-btn:hover {
+    transform: translateY(-2px);
+    box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
+}
+
+.video-output {
+    background: #f8f9fa;
+    padding: 20px;
+    border-radius: 15px;
+    text-align: center;
+    min-height: 400px;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+}
+
+.accordion {
+    background: rgba(255, 255, 255, 0.7);
+    border-radius: 10px;
+    margin-top: 15px;
+    padding: 15px;
+}
+
+.slider-container {
+    background: rgba(255, 255, 255, 0.5);
+    padding: 15px;
+    border-radius: 10px;
+    margin: 10px 0;
+}
+
+body {
+    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
+    background-size: 400% 400%;
+    animation: gradient 15s ease infinite;
+}
+
+@keyframes gradient {
+    0% { background-position: 0% 50%; }
+    50% { background-position: 100% 50%; }
+    100% { background-position: 0% 50%; }
+}
+
+.gr-button-secondary {
+    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+}
+
+.footer {
+    text-align: center;
+    margin-top: 30px;
+    color: #666;
+    font-size: 0.9em;
+}
+"""
+
+# Gradio UI
+with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+    with gr.Column(elem_classes="container"):
+        # Header
+        gr.HTML("""
+        <div class="header">
+            <h1>🎬 AI Video Magic Studio</h1>
+            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
+        </div>
+        """)
+
+        with gr.Row(elem_classes="main-content"):
+            with gr.Column(scale=1):
+                gr.Markdown("### 📸 Input Settings")
+
+                with gr.Column(elem_classes="input-section"):
+                    input_image = gr.Image(
+                        type="pil",
+                        label="🖼️ Upload Your Image",
+                        elem_classes="image-upload"
+                    )
+
+                    prompt_input = gr.Textbox(
+                        label="✨ Animation Prompt",
+                        value=config.default_prompt,
+                        placeholder="Describe how you want your image to move...",
+                        lines=2
+                    )
+
+                    duration_input = gr.Slider(
+                        minimum=round(config.min_frames/config.fixed_fps, 1),
+                        maximum=round(config.max_frames/config.fixed_fps, 1),
+                        step=0.1,
+                        value=2,
+                        label="⏱️ Video Duration (seconds)",
+                        elem_classes="slider-container"
+                    )
+
+                with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
+                    negative_prompt = gr.Textbox(
+                        label="🚫 Negative Prompt",
+                        value=config.default_negative_prompt,
+                        lines=2
+                    )
+
+                    with gr.Row():
+                        seed = gr.Slider(
+                            minimum=0,
+                            maximum=MAX_SEED,
+                            step=1,
+                            value=42,
+                            label="🎲 Seed"
+                        )
+                        randomize_seed = gr.Checkbox(
+                            label="🔀 Randomize",
+                            value=True
+                        )
+
+                    with gr.Row():
+                        height_slider = gr.Slider(
+                            minimum=config.slider_min_h,
+                            maximum=config.slider_max_h,
+                            step=config.mod_value,
+                            value=config.default_height,
+                            label="📏 Height"
+                        )
+                        width_slider = gr.Slider(
+                            minimum=config.slider_min_w,
+                            maximum=config.slider_max_w,
+                            step=config.mod_value,
+                            value=config.default_width,
+                            label="📐 Width"
+                        )
+
+                    steps_slider = gr.Slider(
+                        minimum=1,
+                        maximum=30,
+                        step=1,
+                        value=4,
+                        label="🔧 Quality Steps (4-8 recommended)"
+                    )
+
+                    guidance_scale = gr.Slider(
+                        minimum=0.0,
+                        maximum=20.0,
+                        step=0.5,
+                        value=1.0,
+                        label="🎯 Guidance Scale",
+                        visible=False
+                    )
+
+                generate_btn = gr.Button(
+                    "🎬 Generate Video",
+                    variant="primary",
+                    elem_classes="generate-btn"
+                )
 
-            with gr.Accordion("Advanced Settings", open=False):
-                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
-                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
-                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
-                with gr.Row():
-                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
-                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
-                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
-                guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
-
-            generate_button = gr.Button("Generate Video", variant="primary")
-        with gr.Column():
-            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-
-    input_image_component.upload(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
+            with gr.Column(scale=1):
+                gr.Markdown("### 🎥 Generated Video")
+                video_output = gr.Video(
+                    label="",
+                    autoplay=True,
+                    elem_classes="video-output"
+                )
+
+                gr.HTML("""
+                <div class="footer">
+                    <p>💡 Tip: For best results, use clear images with good lighting</p>
+                </div>
+                """)
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
+            ["forg.jpg", "the frog jumps around", 448, 832],
+        ],
+        inputs=[input_image, prompt_input, height_slider, width_slider],
+        outputs=[video_output, seed],
+        fn=generate_video,
+        cache_examples="lazy"
+    )
+
+    # Event handlers
+    input_image.upload(
+        fn=handle_image_upload,
+        inputs=[input_image],
+        outputs=[height_slider, width_slider]
     )
 
-    input_image_component.clear(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
+    input_image.clear(
+        fn=handle_image_upload,
+        inputs=[input_image],
+        outputs=[height_slider, width_slider]
     )
 
-    ui_inputs = [
-        input_image_component, prompt_input, height_input, width_input,
-        negative_prompt_input, duration_seconds_input,
-        guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
-    ]
-    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
-
-    gr.Examples(
-        examples=[
-            ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
-            ["forg.jpg", "the frog jumps around", 448, 832],
+    generate_btn.click(
+        fn=generate_video,
+        inputs=[
+            input_image, prompt_input, height_slider, width_slider,
+            negative_prompt, duration_input, guidance_scale,
+            steps_slider, seed, randomize_seed
         ],
-        inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
+        outputs=[video_output, seed]
    )
 
 if __name__ == "__main__":
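
Reviewer note: the resize rule that both the removed _calculate_new_dimensions_wan helper and the new VideoGenerator.calculate_dimensions method implement (fit the upload into the ~480x832 pixel budget at its original aspect ratio, then snap down to multiples of 32 and clamp to the slider range) can be checked in isolation. A minimal sketch, not part of the commit, with constants copied from VideoGenerationConfig above and a hypothetical 1024x768 input:

import numpy as np

MOD_VALUE = 32               # model requires dimensions divisible by 32
MAX_AREA = 480.0 * 832.0     # pixel budget of the 480P checkpoint
MIN_H, MAX_H = 128, 896      # height slider bounds
MIN_W, MAX_W = 128, 896      # width slider bounds

def calculate_dimensions(orig_w: int, orig_h: int):
    # Solve h * w = MAX_AREA subject to h / w = orig_h / orig_w.
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(MAX_AREA * aspect_ratio))
    calc_w = round(np.sqrt(MAX_AREA / aspect_ratio))
    # Snap down to the nearest multiple of MOD_VALUE, never below one step.
    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
    # Clamp to the slider range; the upper bound stays a multiple of MOD_VALUE.
    new_h = int(np.clip(calc_h, MIN_H, (MAX_H // MOD_VALUE) * MOD_VALUE))
    new_w = int(np.clip(calc_w, MIN_W, (MAX_W // MOD_VALUE) * MOD_VALUE))
    return new_h, new_w

print(calculate_dimensions(1024, 768))  # hypothetical 4:3 upload -> (544, 704)

The same divisibility constraint is what the Height/Width sliders enforce via step=config.mod_value, and the example shapes in gr.Examples (896x512 and 448x832) already satisfy it.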