rahul7star commited on
Commit
1446eb5
·
verified ·
1 Parent(s): a20dc48

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +48 -28
generate.py CHANGED
@@ -1,8 +1,16 @@
 
 
 
 
 
 
 
 
1
  def generate(args):
2
  rank = int(os.getenv("RANK", 0))
3
  world_size = int(os.getenv("WORLD_SIZE", 1))
4
  local_rank = int(os.getenv("LOCAL_RANK", 0))
5
-
6
  # Set device: use CPU if specified, else use GPU based on rank
7
  if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
8
  device = torch.device("cpu")
@@ -100,16 +108,20 @@ def generate(args):
100
  )
101
 
102
  logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
103
- video = wan_t2v.generate(
104
- args.prompt,
105
- size=SIZE_CONFIGS[args.size],
106
- frame_num=3,
107
- shift=args.sample_shift,
108
- sample_solver=args.sample_solver,
109
- sampling_steps=args.sample_steps,
110
- guide_scale=args.sample_guide_scale,
111
- seed=args.base_seed,
112
- offload_model=args.offload_model)
 
 
 
 
113
 
114
  else: # image-to-video
115
  if args.prompt is None:
@@ -153,17 +165,21 @@ def generate(args):
153
  )
154
 
155
  logging.info("Generating video ...")
156
- video = wan_i2v.generate(
157
- args.prompt,
158
- img,
159
- max_area=MAX_AREA_CONFIGS[args.size],
160
- frame_num=3,
161
- shift=args.sample_shift,
162
- sample_solver=args.sample_solver,
163
- sampling_steps=args.sample_steps,
164
- guide_scale=args.sample_guide_scale,
165
- seed=args.base_seed,
166
- offload_model=args.offload_model)
 
 
 
 
167
 
168
  # Save the output video or image
169
  if rank == 0:
@@ -173,9 +189,13 @@ def generate(args):
173
  suffix = '.png' if "t2i" in args.task else '.mp4'
174
  args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
175
 
176
- if "t2i" in args.task:
177
- logging.info(f"Saving generated image to {args.save_file}")
178
- cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
179
- else:
180
- logging.info(f"Saving generated video to {args.save_file}")
181
- cache_video(tensor=video, save_file=args.save_file)
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import torch
4
+ import torch.distributed as dist
5
+ from PIL import Image
6
+ from datetime import datetime
7
+ from tqdm import tqdm
8
+
9
  def generate(args):
10
  rank = int(os.getenv("RANK", 0))
11
  world_size = int(os.getenv("WORLD_SIZE", 1))
12
  local_rank = int(os.getenv("LOCAL_RANK", 0))
13
+
14
  # Set device: use CPU if specified, else use GPU based on rank
15
  if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
16
  device = torch.device("cpu")
 
108
  )
109
 
110
  logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
111
+ try:
112
+ video = wan_t2v.generate(
113
+ args.prompt,
114
+ size=SIZE_CONFIGS[args.size],
115
+ frame_num=1,
116
+ shift=args.sample_shift,
117
+ sample_solver=args.sample_solver,
118
+ sampling_steps=args.sample_steps,
119
+ guide_scale=args.sample_guide_scale,
120
+ seed=args.base_seed,
121
+ offload_model=args.offload_model)
122
+ except Exception as e:
123
+ logging.error(f"Error during video generation: {e}")
124
+ raise
125
 
126
  else: # image-to-video
127
  if args.prompt is None:
 
165
  )
166
 
167
  logging.info("Generating video ...")
168
+ try:
169
+ video = wan_i2v.generate(
170
+ args.prompt,
171
+ img,
172
+ max_area=MAX_AREA_CONFIGS[args.size],
173
+ frame_num=1,
174
+ shift=args.sample_shift,
175
+ sample_solver=args.sample_solver,
176
+ sampling_steps=args.sample_steps,
177
+ guide_scale=args.sample_guide_scale,
178
+ seed=args.base_seed,
179
+ offload_model=args.offload_model)
180
+ except Exception as e:
181
+ logging.error(f"Error during video generation: {e}")
182
+ raise
183
 
184
  # Save the output video or image
185
  if rank == 0:
 
189
  suffix = '.png' if "t2i" in args.task else '.mp4'
190
  args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
191
 
192
+ try:
193
+ if "t2i" in args.task:
194
+ logging.info(f"Saving generated image to {args.save_file}")
195
+ cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
196
+ else:
197
+ logging.info(f"Saving generated video to {args.save_file}")
198
+ cache_video(tensor=video, save_file=args.save_file)
199
+ except Exception as e:
200
+ logging.error(f"Error saving output: {e}")
201
+ raise