root commited on
Commit
4e212b7
·
1 Parent(s): feeedf2
Files changed (1) hide show
  1. app.py +22 -32
app.py CHANGED
@@ -6,8 +6,6 @@ import random
6
  import numpy as np
7
  import os
8
  import gc
9
- import tempfile
10
- import imageio
11
  from diffusers import AutoencoderKLWan
12
  from wan_pipeline import WanPipeline
13
  from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
@@ -18,7 +16,6 @@ from huggingface_hub import login
18
  # Authenticate with HF
19
  login(token=os.getenv('HF_TOKEN'))
20
 
21
- # Set seed
22
  def set_seed(seed):
23
  random.seed(seed)
24
  os.environ['PYTHONHASHSEED'] = str(seed)
@@ -26,7 +23,6 @@ def set_seed(seed):
26
  torch.manual_seed(seed)
27
  torch.cuda.manual_seed(seed)
28
 
29
- # Model paths
30
  model_paths = {
31
  "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
32
  "sd3.5": "stabilityai/stable-diffusion-3.5-large",
@@ -55,7 +51,10 @@ def load_model(model_name):
55
  return current_model.to("cuda")
56
 
57
  @spaces.GPU(duration=500)
58
- def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False):
 
 
 
59
  model = load_model(model_name)
60
  if seed is None:
61
  seed = random.randint(0, 2**32 - 1)
@@ -82,8 +81,6 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
82
 
83
  return None, None, video1_path, seed
84
 
85
- print("prompt:", prompt)
86
-
87
  if compare_mode:
88
  set_seed(seed)
89
  image1 = model(
@@ -121,7 +118,7 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
121
  else:
122
  return None, image, None, seed
123
 
124
- # Gradio UI
125
  with gr.Blocks() as demo:
126
  gr.HTML("""
127
  <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
@@ -134,31 +131,24 @@ with gr.Blocks() as demo:
134
  """)
135
 
136
  with gr.Row():
137
- prompt = gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt")
138
- model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model")
139
-
140
- with gr.Row():
141
- guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
142
- inference_steps = gr.Slider(10, 100, value=28, step=5, label="Inference Steps")
143
-
144
- with gr.Row():
145
- use_opt_scale = gr.Checkbox(value=True, label="Use Optimized-Scale")
146
- use_zero_init = gr.Checkbox(value=True, label="Use Zero Init")
147
- zero_steps = gr.Slider(0, 20, value=0, step=1, label="Zero out steps")
148
-
149
- with gr.Row():
150
- seed = gr.Number(value=42, label="Seed (Leave blank for random)")
151
- compare_mode = gr.Checkbox(value=True, label="Compare Mode")
152
-
153
- with gr.Row():
154
- out1 = gr.Image(type="pil", label="CFG-Zero* Image")
155
- out2 = gr.Image(type="pil", label="CFG Image")
156
- video = gr.Video(label="Video")
157
- used_seed = gr.Textbox(label="Used Seed")
158
-
159
- generate_btn = gr.Button("Generate")
160
 
161
- # Change logic for when "wan-t2v" is selected
162
  def update_params(model_name):
163
  if model_name == "wan-t2v":
164
  return (
 
6
  import numpy as np
7
  import os
8
  import gc
 
 
9
  from diffusers import AutoencoderKLWan
10
  from wan_pipeline import WanPipeline
11
  from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 
16
  # Authenticate with HF
17
  login(token=os.getenv('HF_TOKEN'))
18
 
 
19
  def set_seed(seed):
20
  random.seed(seed)
21
  os.environ['PYTHONHASHSEED'] = str(seed)
 
23
  torch.manual_seed(seed)
24
  torch.cuda.manual_seed(seed)
25
 
 
26
  model_paths = {
27
  "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
28
  "sd3.5": "stabilityai/stable-diffusion-3.5-large",
 
51
  return current_model.to("cuda")
52
 
53
  @spaces.GPU(duration=500)
54
+ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50,
55
+ use_cfg_zero_star=True, use_zero_init=True, zero_steps=0,
56
+ seed=None, compare_mode=False):
57
+
58
  model = load_model(model_name)
59
  if seed is None:
60
  seed = random.randint(0, 2**32 - 1)
 
81
 
82
  return None, None, video1_path, seed
83
 
 
 
84
  if compare_mode:
85
  set_seed(seed)
86
  image1 = model(
 
118
  else:
119
  return None, image, None, seed
120
 
121
+ # Gradio UI with left-right layout
122
  with gr.Blocks() as demo:
123
  gr.HTML("""
124
  <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
 
131
  """)
132
 
133
  with gr.Row():
134
+ with gr.Column(scale=1):
135
+ prompt = gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt")
136
+ model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model")
137
+ guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
138
+ inference_steps = gr.Slider(10, 100, value=28, step=5, label="Inference Steps")
139
+ use_opt_scale = gr.Checkbox(value=True, label="Use Optimized-Scale")
140
+ use_zero_init = gr.Checkbox(value=True, label="Use Zero Init")
141
+ zero_steps = gr.Slider(0, 20, value=0, step=1, label="Zero out steps")
142
+ seed = gr.Number(value=42, label="Seed (Leave blank for random)")
143
+ compare_mode = gr.Checkbox(value=True, label="Compare Mode")
144
+ generate_btn = gr.Button("Generate")
145
+
146
+ with gr.Column(scale=2):
147
+ out1 = gr.Image(type="pil", label="CFG-Zero* Image")
148
+ out2 = gr.Image(type="pil", label="CFG Image")
149
+ video = gr.Video(label="Video")
150
+ used_seed = gr.Textbox(label="Used Seed")
 
 
 
 
 
 
151
 
 
152
  def update_params(model_name):
153
  if model_name == "wan-t2v":
154
  return (