Update generate.py
generate.py  +48 -28  CHANGED
@@ -1,8 +1,16 @@
+import logging
+import os
+import torch
+import torch.distributed as dist
+from PIL import Image
+from datetime import datetime
+from tqdm import tqdm
+
 def generate(args):
     rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
     local_rank = int(os.getenv("LOCAL_RANK", 0))
-
+
     # Set device: use CPU if specified, else use GPU based on rank
     if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
         device = torch.device("cpu")

@@ -100,16 +108,20 @@ def generate(args):
         )

         logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
-
-
-
-
-
-
-
-
-
-
+        try:
+            video = wan_t2v.generate(
+                args.prompt,
+                size=SIZE_CONFIGS[args.size],
+                frame_num=1,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
+        except Exception as e:
+            logging.error(f"Error during video generation: {e}")
+            raise

     else: # image-to-video
         if args.prompt is None:

@@ -153,17 +165,21 @@
         )

         logging.info("Generating video ...")
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            video = wan_i2v.generate(
+                args.prompt,
+                img,
+                max_area=MAX_AREA_CONFIGS[args.size],
+                frame_num=1,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
+        except Exception as e:
+            logging.error(f"Error during video generation: {e}")
+            raise

     # Save the output video or image
     if rank == 0:

@@ -173,9 +189,13 @@
         suffix = '.png' if "t2i" in args.task else '.mp4'
         args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix

-
-
-
-
-
-
+        try:
+            if "t2i" in args.task:
+                logging.info(f"Saving generated image to {args.save_file}")
+                cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
+            else:
+                logging.info(f"Saving generated video to {args.save_file}")
+                cache_video(tensor=video, save_file=args.save_file)
+        except Exception as e:
+            logging.error(f"Error saving output: {e}")
+            raise
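
The commit applies the same pattern in all three places it touches: announce the step via logging, wrap the call in try/except, log the exception, and re-raise so the process still fails loudly under the launcher. A minimal standalone sketch of that pattern follows; the run_step helper and its arguments are hypothetical and only illustrate the structure, while the actual calls are wan_t2v.generate / wan_i2v.generate and cache_image / cache_video as shown in the diff above.

    import logging

    logging.basicConfig(level=logging.INFO)

    def run_step(label, fn, *args, **kwargs):
        # Hypothetical helper mirroring the commit's pattern:
        # announce the step, run it, log any failure, and re-raise.
        logging.info(f"{label} ...")
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logging.error(f"Error during {label}: {e}")
            raise

    # Usage sketch (the generator object and its keyword arguments are
    # assumptions modeled on the calls shown in the diff):
    #   video = run_step("video generation", wan_t2v.generate, args.prompt,
    #                    frame_num=1, seed=args.base_seed)
    #   run_step("saving output", cache_video, tensor=video, save_file=args.save_file)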