svjack committed
Commit 1bb2f87 (verified) · 1 Parent(s): 63e3b4f

Upload 15 files

fpack_cache_latents.py CHANGED
@@ -2,13 +2,14 @@ import argparse
2
  import logging
3
  import math
4
  import os
5
- from typing import List
6
 
7
  import numpy as np
8
  import torch
9
  import torch.nn.functional as F
10
  from tqdm import tqdm
11
  from transformers import SiglipImageProcessor, SiglipVisionModel
 
12
 
13
  from dataset import config_utils
14
  from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
@@ -28,15 +29,20 @@ def encode_and_save_batch(
28
  feature_extractor: SiglipImageProcessor,
29
  image_encoder: SiglipVisionModel,
30
  batch: List[ItemInfo],
31
- latent_window_size: int,
32
  vanilla_sampling: bool = False,
33
  one_frame: bool = False,
 
 
34
  ):
35
  """Encode a batch of original RGB videos and save FramePack section caches."""
36
  if one_frame:
37
- encode_and_save_batch_one_frame(vae, feature_extractor, image_encoder, batch, latent_window_size, vanilla_sampling)
 
 
38
  return
39
 
 
 
40
  # Stack batch into tensor (B,C,F,H,W) in RGB order
41
  contents = torch.stack([torch.from_numpy(item.content) for item in batch])
42
  if len(contents.shape) == 4:
@@ -238,34 +244,68 @@ def encode_and_save_batch_one_frame(
238
  feature_extractor: SiglipImageProcessor,
239
  image_encoder: SiglipVisionModel,
240
  batch: List[ItemInfo],
241
- latent_window_size: int,
242
  vanilla_sampling: bool = False,
 
 
243
  ):
244
  # item.content: target image (H, W, C)
245
- # item.control_content: start image (H, W, C)
246
-
247
- # Stack batch into tensor (B,F,H,W,C) in RGB order.
248
- contents = torch.stack(
249
- [torch.stack([torch.from_numpy(item.control_content), torch.from_numpy(item.content)]) for item in batch]
250
- )
251
 
252
  contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
253
  contents = contents.to(vae.device, dtype=vae.dtype)
254
  contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
255
 
256
- height, width = contents.shape[3], contents.shape[4]
257
  if height < 8 or width < 8:
258
  item = batch[0] # other items should have the same size
259
  raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
260
 
261
- # VAE encode (list of tensor -> stack)
262
- start_latents = hunyuan.vae_encode(contents[:, :, 0:1], vae) # include scaling factor
263
- start_latents = start_latents.to("cpu") # (B, C, 1, H/8, W/8)
264
- latents = hunyuan.vae_encode(contents[:, :, 1:], vae) # include scaling factor
265
- latents = latents.to("cpu") # (B, C, 1, H/8, W/8)
 
 
 
 
 
 
266
 
267
  # Vision encoding per‑item (once): use control content because it is the start image
268
- images = [item.control_content for item in batch] # list of [H, W, C]
269
 
270
  # encode image with image encoder
271
  image_embeddings = []
@@ -276,56 +316,74 @@ def encode_and_save_batch_one_frame(
276
  image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
277
  image_embeddings = image_embeddings.to("cpu") # Save memory
278
 
279
- # history latents is always zeroes for one frame training
280
- history_latents = torch.zeros(
281
- (1, latents.shape[1], 1 + 2 + 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype
282
- ) # C=16 for HY
283
-
284
- # indices generation (same as inference)
285
- indices = torch.arange(0, sum([1, latent_window_size, 1, 2, 16])).unsqueeze(0)
286
- (
287
- clean_latent_indices_pre, # Index for start_latent
288
- latent_indices, # Indices for the target latents to predict
289
- clean_latent_indices_post, # Index for the most recent history frame
290
- clean_latent_2x_indices, # Indices for the next 2 history frames
291
- clean_latent_4x_indices, # Indices for the next 16 history frames
292
- ) = indices.split([1, latent_window_size, 1, 2, 16], dim=1)
293
-
294
- # Indices for clean_latents (start + recent history)
295
- latent_indices = latent_indices[:, -1:] # Only the last index is used for one frame training
296
- clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
297
-
298
- # clean latents preparation for all items (emulating inference)
299
- clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
300
-
301
  for b, item in enumerate(batch):
302
- original_latent_cache_path = item.latent_cache_path
303
 
304
  # clean latents preparation (emulating inference)
305
- clean_latents_pre = start_latents[b : b + 1]
306
- clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Combine start frame + placeholder
 
 
307
 
308
  # Target latents for this section (ground truth)
309
- target_latents = latents[b : b + 1]
 
 
 
 
 
 
 
 
 
 
310
 
311
  # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
312
  save_latent_cache_framepack(
313
  item_info=item,
314
- latent=target_latents.squeeze(0), # Ground truth for this section
315
- latent_indices=latent_indices.squeeze(0), # Indices for the ground truth section
316
- clean_latents=clean_latents.squeeze(0), # Start frame + history placeholder
317
- clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for start frame + history placeholder
318
- clean_latents_2x=clean_latents_2x.squeeze(0), # History placeholder
319
- clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for history placeholder
320
- clean_latents_4x=clean_latents_4x.squeeze(0), # History placeholder
321
- clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for history placeholder
322
  image_embeddings=image_embeddings[b],
323
  )
324
 
325
 
326
  def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
327
  parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
328
- parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
329
  parser.add_argument(
330
  "--f1",
331
  action="store_true",
@@ -336,6 +394,16 @@ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.Argument
336
  action="store_true",
337
  help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
338
  )
339
  return parser
340
 
341
 
@@ -373,7 +441,9 @@ def main(args: argparse.Namespace):
373
 
374
  # encoding closure
375
  def encode(batch: List[ItemInfo]):
376
- encode_and_save_batch(vae, feature_extractor, image_encoder, batch, args.latent_window_size, args.f1, args.one_frame)
 
 
377
 
378
  # reuse core loop from cache_latents with no change
379
  encode_datasets_framepack(datasets, encode, args)
@@ -403,7 +473,7 @@ def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, arg
403
  all_existing = os.path.exists(item.latent_cache_path)
404
  else:
405
  latent_f = (item.frame_count - 1) // 4 + 1
406
- num_sections = max(1, math.floor((latent_f - 1) / args.latent_window_size)) # min 1 section
407
  all_existing = True
408
  for sec in range(num_sections):
409
  p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
 
2
  import logging
3
  import math
4
  import os
5
+ from typing import List, Optional
6
 
7
  import numpy as np
8
  import torch
9
  import torch.nn.functional as F
10
  from tqdm import tqdm
11
  from transformers import SiglipImageProcessor, SiglipVisionModel
12
+ from PIL import Image
13
 
14
  from dataset import config_utils
15
  from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
 
29
  feature_extractor: SiglipImageProcessor,
30
  image_encoder: SiglipVisionModel,
31
  batch: List[ItemInfo],
 
32
  vanilla_sampling: bool = False,
33
  one_frame: bool = False,
34
+ one_frame_no_2x: bool = False,
35
+ one_frame_no_4x: bool = False,
36
  ):
37
  """Encode a batch of original RGB videos and save FramePack section caches."""
38
  if one_frame:
39
+ encode_and_save_batch_one_frame(
40
+ vae, feature_extractor, image_encoder, batch, vanilla_sampling, one_frame_no_2x, one_frame_no_4x
41
+ )
42
  return
43
 
44
+ latent_window_size = batch[0].fp_latent_window_size # all items should have the same window size
45
+
46
  # Stack batch into tensor (B,C,F,H,W) in RGB order
47
  contents = torch.stack([torch.from_numpy(item.content) for item in batch])
48
  if len(contents.shape) == 4:
 
244
  feature_extractor: SiglipImageProcessor,
245
  image_encoder: SiglipVisionModel,
246
  batch: List[ItemInfo],
 
247
  vanilla_sampling: bool = False,
248
+ one_frame_no_2x: bool = False,
249
+ one_frame_no_4x: bool = False,
250
  ):
251
  # item.content: target image (H, W, C)
252
+ # item.control_content: list of images (H, W, C)
253
+
254
+ # Stack batch into tensor (B,F,H,W,C) in RGB order. The numbers of control content for each item are the same.
255
+ contents = []
256
+ content_masks: list[list[Optional[torch.Tensor]]] = []
257
+ for item in batch:
258
+ item_contents = item.control_content + [item.content]
259
+
260
+ item_masks = []
261
+ for i, c in enumerate(item_contents):
262
+ if c.shape[-1] == 4: # RGBA
263
+ item_contents[i] = c[..., :3] # remove alpha channel from content
264
+
265
+ alpha = c[..., 3] # extract alpha channel
266
+ mask_image = Image.fromarray(alpha, mode="L")
267
+ width, height = mask_image.size
268
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
269
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
270
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
271
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
272
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
273
+ mask_image = mask_image.to(torch.float32)
274
+ content_mask = mask_image
275
+ else:
276
+ content_mask = None
277
+
278
+ item_masks.append(content_mask)
279
+
280
+ item_contents = [torch.from_numpy(c) for c in item_contents]
281
+ contents.append(torch.stack(item_contents, dim=0)) # list of [F, H, W, C]
282
+ content_masks.append(item_masks)
283
+
284
+ contents = torch.stack(contents, dim=0) # B, F, H, W, C. F is control frames + target frame
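The RGBA branch above reduces each control image's alpha channel to the latent grid (1/8 of the pixel resolution) and keeps it as a broadcastable B,C,F,H,W mask. A minimal standalone sketch of that conversion (illustration only, not part of this commit; shapes assumed):

import numpy as np
import torch
from PIL import Image

def alpha_to_latent_mask(rgba: np.ndarray) -> torch.Tensor:
    # rgba: (H, W, 4) uint8 image -> (1, 1, 1, H//8, W//8) float mask in [0, 1]
    alpha = rgba[..., 3]
    mask = Image.fromarray(alpha, mode="L")
    w, h = mask.size
    mask = mask.resize((w // 8, h // 8), Image.LANCZOS)  # match the VAE's 8x spatial stride
    mask = torch.from_numpy(np.array(mask)).float() / 255.0
    return mask[None, None, None]  # add B, C, F axes so it broadcasts over the latents

rgba = np.zeros((64, 64, 4), dtype=np.uint8)
rgba[..., 3] = 255  # fully opaque alpha
print(alpha_to_latent_mask(rgba).shape)  # torch.Size([1, 1, 1, 8, 8])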
285
 
286
  contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
287
  contents = contents.to(vae.device, dtype=vae.dtype)
288
  contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
289
 
290
+ height, width = contents.shape[-2], contents.shape[-1]
291
  if height < 8 or width < 8:
292
  item = batch[0] # other items should have the same size
293
  raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
294
 
295
+ # VAE encode: we need to encode one frame at a time because VAE encoder has stride=4 for the time dimension except for the first frame.
296
+ latents = [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
297
+ latents = torch.cat(latents, dim=2) # B, C, F, H/8, W/8
298
+
299
+ # apply alphas to latents
300
+ for b, item in enumerate(batch):
301
+ for i, content_mask in enumerate(content_masks[b]):
302
+ if content_mask is not None:
303
+ # apply mask to the latents
304
+ # print(f"Applying content mask for item {item.item_key}, frame {i}")
305
+ latents[b : b + 1, :, i : i + 1] *= content_mask
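The frame-by-frame encode above matters because the causal video VAE maps F input frames to 1 + (F - 1) // 4 latent frames (temporal stride 4 after the first frame), so encoding all control frames plus the target in one call would merge them. A small sketch of the arithmetic (illustration only, not from this commit):

def latent_frame_count(num_frames: int) -> int:
    # causal video VAE: the first frame maps 1:1, every further 4 frames share one latent frame
    return 1 + (num_frames - 1) // 4

print(latent_frame_count(3))                         # 1 -> three frames encoded together collapse into one latent
print(sum(latent_frame_count(1) for _ in range(3)))  # 3 -> one latent per frame when encoded one at a time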
306
 
307
  # Vision encoding per‑item (once): use control content because it is the start image
308
+ images = [item.control_content[0] for item in batch] # list of [H, W, C]
309
 
310
  # encode image with image encoder
311
  image_embeddings = []
 
316
  image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
317
  image_embeddings = image_embeddings.to("cpu") # Save memory
318
 
319
+ # save cache for each item in the batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  for b, item in enumerate(batch):
321
+ # indices generation (same as inference): each item may have different clean_latent_indices, so we generate them per item
322
+ clean_latent_indices = item.fp_1f_clean_indices # list of indices for clean latents
323
+ if clean_latent_indices is None or len(clean_latent_indices) == 0:
324
+ logger.warning(
325
+ f"Item {item.item_key} has no clean_latent_indices defined, using default indices for one frame training."
326
+ )
327
+ clean_latent_indices = [0]
328
+
329
+ if not item.fp_1f_no_post:
330
+ clean_latent_indices = clean_latent_indices + [1 + item.fp_latent_window_size]
331
+ clean_latent_indices = torch.Tensor(clean_latent_indices).long() # N
332
+
333
+ latent_index = torch.Tensor([item.fp_1f_target_index]).long() # 1
334
+
335
+ # zero values do not need to be cached even if one_frame_no_2x or one_frame_no_4x is False
336
+ clean_latents_2x = None
337
+ clean_latents_4x = None
338
+
339
+ if one_frame_no_2x:
340
+ clean_latent_2x_indices = None
341
+ else:
342
+ index = 1 + item.fp_latent_window_size + 1
343
+ clean_latent_2x_indices = torch.arange(index, index + 2) # 2
344
+
345
+ if one_frame_no_4x:
346
+ clean_latent_4x_indices = None
347
+ else:
348
+ index = 1 + item.fp_latent_window_size + 1 + 2
349
+ clean_latent_4x_indices = torch.arange(index, index + 16) # 16
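As a concrete illustration of the index layout built above, with assumed values fp_latent_window_size = 9, fp_1f_clean_indices = [0], fp_1f_target_index = 9, fp_1f_no_post = False, and neither no_2x nor no_4x set:

import torch

window = 9
clean_latent_indices = torch.tensor([0, 1 + window]).long()                          # [0, 10]: start image + zero "post" frame
latent_index = torch.tensor([9]).long()                                              # target frame, last slot of the window
clean_latent_2x_indices = torch.arange(1 + window + 1, 1 + window + 1 + 2)           # [11, 12]
clean_latent_4x_indices = torch.arange(1 + window + 1 + 2, 1 + window + 1 + 2 + 16)  # indices 13..28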
350
 
351
  # clean latents preparation (emulating inference)
352
+ clean_latents = latents[b, :, :-1] # C, F, H, W
353
+ if not item.fp_1f_no_post:
354
+ # If zero post is enabled, we need to add a zero frame at the end
355
+ clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0) # C, F+1, H, W
356
 
357
  # Target latents for this section (ground truth)
358
+ target_latents = latents[b, :, -1:] # C, 1, H, W
359
+
360
+ print(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
361
+ print(f" Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
362
+ print(f" Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
363
+ print(f" Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
364
+ print(
365
+ f" Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
366
+ f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
367
+ )
368
+ print(f" Image embeddings: {image_embeddings[b].shape}")
369
 
370
  # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
371
  save_latent_cache_framepack(
372
  item_info=item,
373
+ latent=target_latents, # Ground truth for this section
374
+ latent_indices=latent_index, # Indices for the ground truth section
375
+ clean_latents=clean_latents, # Start frame + history placeholder
376
+ clean_latent_indices=clean_latent_indices, # Indices for start frame + history placeholder
377
+ clean_latents_2x=clean_latents_2x, # History placeholder
378
+ clean_latent_2x_indices=clean_latent_2x_indices, # Indices for history placeholder
379
+ clean_latents_4x=clean_latents_4x, # History placeholder
380
+ clean_latent_4x_indices=clean_latent_4x_indices, # Indices for history placeholder
381
  image_embeddings=image_embeddings[b],
382
  )
383
 
384
 
385
  def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
386
  parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
 
387
  parser.add_argument(
388
  "--f1",
389
  action="store_true",
 
394
  action="store_true",
395
  help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
396
  )
397
+ parser.add_argument(
398
+ "--one_frame_no_2x",
399
+ action="store_true",
400
+ help="Do not use clean_latents_2x and clean_latent_2x_indices for one frame training.",
401
+ )
402
+ parser.add_argument(
403
+ "--one_frame_no_4x",
404
+ action="store_true",
405
+ help="Do not use clean_latents_4x and clean_latent_4x_indices for one frame training.",
406
+ )
407
  return parser
408
 
409
 
 
441
 
442
  # encoding closure
443
  def encode(batch: List[ItemInfo]):
444
+ encode_and_save_batch(
445
+ vae, feature_extractor, image_encoder, batch, args.f1, args.one_frame, args.one_frame_no_2x, args.one_frame_no_4x
446
+ )
447
 
448
  # reuse core loop from cache_latents with no change
449
  encode_datasets_framepack(datasets, encode, args)
 
473
  all_existing = os.path.exists(item.latent_cache_path)
474
  else:
475
  latent_f = (item.frame_count - 1) // 4 + 1
476
+ num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size)) # min 1 section
477
  all_existing = True
478
  for sec in range(num_sections):
479
  p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
fpack_generate_video.py CHANGED
@@ -114,20 +114,17 @@ def parse_args() -> argparse.Namespace:
114
  "--one_frame_inference",
115
  type=str,
116
  default=None,
117
- help="one frame inference, default is None, comma separated values from 'zero_post', 'no_2x', 'no_4x' and 'no_post'.",
118
  )
119
  parser.add_argument(
120
- "--image_mask_path",
121
- type=str,
122
- default=None,
123
- help="path to image mask for one frame inference. If specified, it will be used as mask for input image.",
124
  )
125
  parser.add_argument(
126
- "--end_image_mask_path",
127
  type=str,
128
  default=None,
129
  nargs="*",
130
- help="path to end (reference) image mask for one frame inference. If specified, it will be used as mask for end image.",
131
  )
132
  parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
133
  parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
@@ -154,7 +151,7 @@ def parse_args() -> argparse.Namespace:
154
  default=None,
155
  help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
156
  )
157
- parser.add_argument("--end_image_path", type=str, nargs="*", default=None, help="path to end image for image2video inference")
158
  parser.add_argument(
159
  "--latent_paddings",
160
  type=str,
@@ -180,6 +177,16 @@ def parse_args() -> argparse.Namespace:
180
  parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
181
  parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
182
  # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
183
  parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
184
  parser.add_argument(
185
  "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
@@ -248,9 +255,9 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
248
 
249
  # Create dictionary of overrides
250
  overrides = {"prompt": prompt}
251
- # Initialize end_image_path and end_image_mask_path as a list to accommodate multiple paths
252
- overrides["end_image_path"] = []
253
- overrides["end_image_mask_path"] = []
254
 
255
  for part in parts[1:]:
256
  if not part.strip():
@@ -276,8 +283,8 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
276
  # overrides["flow_shift"] = float(value)
277
  elif option == "i":
278
  overrides["image_path"] = value
279
- elif option == "im":
280
- overrides["image_mask_path"] = value
281
  # elif option == "cn":
282
  # overrides["control_path"] = value
283
  elif option == "n":
@@ -285,17 +292,19 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
285
  elif option == "vs": # video_sections
286
  overrides["video_sections"] = int(value)
287
  elif option == "ei": # end_image_path
288
- overrides["end_image_path"].append(value)
289
- elif option == "eim": # end_image_mask_path
290
- overrides["end_image_mask_path"].append(value)
 
 
291
  elif option == "of": # one_frame_inference
292
  overrides["one_frame_inference"] = value
293
 
294
- # If no end_image_path was provided, remove the empty list
295
- if not overrides["end_image_path"]:
296
- del overrides["end_image_path"]
297
- if not overrides["end_image_mask_path"]:
298
- del overrides["end_image_mask_path"]
299
 
300
  return overrides
301
 
@@ -366,6 +375,13 @@ def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVid
366
 
367
  # do not fp8 optimize because we will merge LoRA weights
368
  model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
 
 
 
 
 
 
 
369
  return model
370
 
371
 
@@ -558,30 +574,44 @@ def prepare_i2v_inputs(
558
 
559
  # prepare image
560
  def preprocess_image(image_path: str):
561
- image = Image.open(image_path).convert("RGB")
 
 
 
 
 
562
 
563
  image_np = np.array(image) # PIL to numpy, HWC
564
 
565
  image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
566
  image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
567
  image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
568
- return image_tensor, image_np
569
 
570
  section_image_paths = parse_section_strings(args.image_path)
571
 
572
  section_images = {}
573
  for index, image_path in section_image_paths.items():
574
- img_tensor, img_np = preprocess_image(image_path)
575
  section_images[index] = (img_tensor, img_np)
576
 
 
 
 
 
 
 
577
  # check end images
578
- if args.end_image_path is not None and len(args.end_image_path) > 0:
579
- end_image_tensors = []
580
- for end_img_path in args.end_image_path:
581
- end_image_tensor, _ = preprocess_image(end_img_path)
582
- end_image_tensors.append(end_image_tensor)
 
 
583
  else:
584
- end_image_tensors = None
 
585
 
586
  # configure negative prompt
587
  n_prompt = args.negative_prompt if args.negative_prompt else ""
@@ -644,6 +674,7 @@ def prepare_i2v_inputs(
644
  image_encoder.to(device)
645
 
646
  # encode image with image encoder
 
647
  section_image_encoder_last_hidden_states = {}
648
  for index, (img_tensor, img_np) in section_images.items():
649
  with torch.no_grad():
@@ -666,14 +697,14 @@ def prepare_i2v_inputs(
666
  start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
667
  section_start_latents[index] = start_latent
668
 
669
- # end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
670
- if end_image_tensors is not None:
671
- end_latents = []
672
- for end_image_tensor in end_image_tensors:
673
- end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu()
674
- end_latents.append(end_latent)
675
- else:
676
- end_latents = None
677
 
678
  vae.to("cpu") # move VAE to CPU to save memory
679
  clean_memory_on_device(device)
@@ -710,7 +741,7 @@ def prepare_i2v_inputs(
710
  }
711
  arg_c_img[index] = arg_c_img_i
712
 
713
- return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latents
714
 
715
 
716
  # def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
@@ -930,13 +961,15 @@ def generate(
930
  if shared_models is not None:
931
  # Use shared models and encoded data
932
  vae = shared_models.get("vae")
933
- height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(
934
- args, device, vae, shared_models
935
  )
936
  else:
937
  # prepare inputs without shared models
938
  vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
939
- height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(args, device, vae)
 
 
940
 
941
  if shared_models is None or "model" not in shared_models:
942
  # load DiT model
@@ -986,294 +1019,231 @@ def generate(
986
  for mode in args.one_frame_inference.split(","):
987
  one_frame_inference.add(mode.strip())
988
 
989
- # prepare history latents
990
- history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
991
- if end_latents is not None and not f1_mode:
992
- logger.info(f"Use end image(s): {args.end_image_path}")
993
- for i, end_latent in enumerate(end_latents):
994
- history_latents[:, :, i + 1 : i + 2] = end_latent.to(history_latents)
995
-
996
- # prepare clean latents and indices
997
- if not f1_mode:
998
- # Inverted Anti-drifting
999
- total_generated_latent_frames = 0
1000
- latent_paddings = reversed(range(total_latent_sections))
1001
-
1002
- if total_latent_sections > 4 and one_frame_inference is None:
1003
- # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
1004
- # items looks better than expanding it when total_latent_sections > 4
1005
- # One can try to remove below trick and just
1006
- # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
1007
- # 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
1008
- latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
1009
-
1010
- if args.latent_paddings is not None:
1011
- # parse user defined latent paddings
1012
- user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
1013
- if len(user_latent_paddings) < total_latent_sections:
1014
- print(
1015
- f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
1016
- )
1017
- print(f"Use default paddings instead for unspecified sections.")
1018
- latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
1019
- elif len(user_latent_paddings) > total_latent_sections:
1020
- print(
1021
- f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
1022
- )
1023
- print(f"Use only first {total_latent_sections} paddings instead.")
1024
- latent_paddings = user_latent_paddings[:total_latent_sections]
1025
- else:
1026
- latent_paddings = user_latent_paddings
1027
  else:
1028
- start_latent = context_img[0]["start_latent"]
1029
- history_latents = torch.cat([history_latents, start_latent], dim=2)
1030
- total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
1031
- latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
1032
-
1033
- latent_paddings = list(latent_paddings) # make sure it's a list
1034
- for loop_index in range(total_latent_sections):
1035
- latent_padding = latent_paddings[loop_index]
1036
 
 
1037
  if not f1_mode:
1038
  # Inverted Anti-drifting
1039
- section_index_reverse = loop_index # 0, 1, 2, 3
1040
- section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
1041
- section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
1042
-
1043
- is_last_section = section_index == 0
1044
- is_first_section = section_index_reverse == 0
1045
- latent_padding_size = latent_padding * latent_window_size
1046
-
1047
- logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
1048
- else:
1049
- section_index = loop_index # 0, 1, 2, 3
1050
- section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
1051
- is_last_section = loop_index == total_latent_sections - 1
1052
- is_first_section = loop_index == 0
1053
- latent_padding_size = 0 # dummy padding for F1 mode
1054
-
1055
- # select start latent
1056
- if section_index_from_last in context_img:
1057
- image_index = section_index_from_last
1058
- elif section_index in context_img:
1059
- image_index = section_index
1060
- else:
1061
- image_index = 0
1062
-
1063
- start_latent = context_img[image_index]["start_latent"]
1064
- image_path = context_img[image_index]["image_path"]
1065
- if image_index != 0: # use section image other than section 0
1066
- logger.info(f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}")
1067
-
1068
- if not f1_mode:
1069
- # Inverted Anti-drifting
1070
- indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
1071
- (
1072
- clean_latent_indices_pre,
1073
- blank_indices,
1074
- latent_indices,
1075
- clean_latent_indices_post,
1076
- clean_latent_2x_indices,
1077
- clean_latent_4x_indices,
1078
- ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
1079
- clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
1080
-
1081
- clean_latents_pre = start_latent.to(history_latents)
1082
- clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
1083
- [1, 2, 16], dim=2
1084
- )
1085
- clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
1086
-
1087
- if end_latents is not None:
1088
- clean_latents = torch.cat([clean_latents_pre, history_latents[:, :, : len(end_latents)]], dim=2)
1089
- clean_latent_indices_extended = torch.zeros(1, 1 + len(end_latents), dtype=clean_latent_indices.dtype)
1090
- clean_latent_indices_extended[:, :2] = clean_latent_indices
1091
- clean_latent_indices = clean_latent_indices_extended
1092
-
1093
- else:
1094
- # F1 mode
1095
- indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
1096
- (
1097
- clean_latent_indices_start,
1098
- clean_latent_4x_indices,
1099
- clean_latent_2x_indices,
1100
- clean_latent_1x_indices,
1101
- latent_indices,
1102
- ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
1103
- clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
1104
-
1105
- clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
1106
- [16, 2, 1], dim=2
1107
- )
1108
- clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
1109
-
1110
- # if use_teacache:
1111
- # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
1112
- # else:
1113
- # transformer.initialize_teacache(enable_teacache=False)
1114
-
1115
- # prepare conditioning inputs
1116
- if section_index_from_last in context:
1117
- prompt_index = section_index_from_last
1118
- elif section_index in context:
1119
- prompt_index = section_index
1120
  else:
1121
- prompt_index = 0
 
1122
 
1123
- context_for_index = context[prompt_index]
1124
- # if args.section_prompts is not None:
1125
- logger.info(f"Section {section_index}: {context_for_index['prompt']}")
 
 
 
1126
 
1127
- llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
1128
- llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
1129
- clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1130
 
1131
- image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
1132
- device, dtype=torch.bfloat16
1133
- )
1134
 
1135
- llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
1136
- llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
1137
- clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1138
 
1139
- # call DiT model to generate latents
1140
- sample_num_frames = num_frames
1141
- if one_frame_inference is not None:
1142
- # one frame inference
1143
- latent_indices = latent_indices[:, -1:] # only use the last frame (default)
1144
- sample_num_frames = 1
1145
-
1146
- def get_latent_mask(mask_path: str):
1147
- mask_image = Image.open(mask_path).convert("L") # grayscale
1148
- mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
1149
- mask_image = np.array(mask_image) # PIL to numpy, HWC
1150
- mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
1151
- mask_image = mask_image.squeeze(-1) # HWC -> HW
1152
- mask_image = mask_image.unsqueeze(0).unsqueeze(0) # HW -> 11HW
1153
- mask_image = mask_image.to(clean_latents)
1154
- return mask_image
1155
-
1156
- if args.image_mask_path is not None:
1157
- mask_image = get_latent_mask(args.image_mask_path)
1158
- logger.info(f"Apply mask for clean latents (start image): {args.image_mask_path}, shape: {mask_image.shape}")
1159
- clean_latents[:, :, 0, :, :] = clean_latents[:, :, 0, :, :] * mask_image
1160
- if args.end_image_mask_path is not None and len(args.end_image_mask_path) > 0:
1161
- # # apply mask for clean latents 1x (end image)
1162
- count = min(len(args.end_image_mask_path), len(end_latents))
1163
- for i in range(count):
1164
- mask_image = get_latent_mask(args.end_image_mask_path[i])
1165
- logger.info(
1166
- f"Apply mask for clean latents 1x (end image) for {i+1}: {args.end_image_mask_path[i]}, shape: {mask_image.shape}"
1167
- )
1168
- clean_latents[:, :, i + 1 : i + 2, :, :] = clean_latents[:, :, i + 1 : i + 2, :, :] * mask_image
1169
-
1170
- for one_frame_param in one_frame_inference:
1171
- if one_frame_param.startswith("target_index="):
1172
- target_index = int(one_frame_param.split("=")[1])
1173
- latent_indices[:, 0] = target_index
1174
- logger.info(f"Set index for target: {target_index}")
1175
- elif one_frame_param.startswith("start_index="):
1176
- start_index = int(one_frame_param.split("=")[1])
1177
- clean_latent_indices[:, 0] = start_index
1178
- logger.info(f"Set index for clean latent pre (start image): {start_index}")
1179
- elif one_frame_param.startswith("history_index="):
1180
- history_indices = one_frame_param.split("=")[1].split(";")
1181
- i = 0
1182
- while i < len(history_indices) and i < len(end_latents):
1183
- history_index = int(history_indices[i])
1184
- clean_latent_indices[:, 1 + i] = history_index
1185
- i += 1
1186
- while i < len(end_latents):
1187
- clean_latent_indices[:, 1 + i] = history_index
1188
- i += 1
1189
- logger.info(f"Set index for clean latent post (end image): {history_indices}")
1190
-
1191
- if "no_2x" in one_frame_inference:
1192
- clean_latents_2x = None
1193
- clean_latent_2x_indices = None
1194
- logger.info(f"No clean_latents_2x")
1195
- if "no_4x" in one_frame_inference:
1196
- clean_latents_4x = None
1197
- clean_latent_4x_indices = None
1198
- logger.info(f"No clean_latents_4x")
1199
- if "no_post" in one_frame_inference:
1200
- clean_latents = clean_latents[:, :, :1, :, :]
1201
- clean_latent_indices = clean_latent_indices[:, :1]
1202
- logger.info(f"No clean_latents post")
1203
- elif "zero_post" in one_frame_inference:
1204
- # zero out the history latents. this seems to prevent the images from corrupting
1205
- clean_latents[:, :, 1:, :, :] = torch.zeros_like(clean_latents[:, :, 1:, :, :])
1206
- logger.info(f"Zero out clean_latents post")
1207
 
1208
- logger.info(
1209
- f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
1210
  )
1211
 
1212
- generated_latents = sample_hunyuan(
1213
- transformer=model,
1214
- sampler=args.sample_solver,
1215
- width=width,
1216
- height=height,
1217
- frames=sample_num_frames,
1218
- real_guidance_scale=args.guidance_scale,
1219
- distilled_guidance_scale=args.embedded_cfg_scale,
1220
- guidance_rescale=args.guidance_rescale,
1221
- # shift=3.0,
1222
- num_inference_steps=args.infer_steps,
1223
- generator=seed_g,
1224
- prompt_embeds=llama_vec,
1225
- prompt_embeds_mask=llama_attention_mask,
1226
- prompt_poolers=clip_l_pooler,
1227
- negative_prompt_embeds=llama_vec_n,
1228
- negative_prompt_embeds_mask=llama_attention_mask_n,
1229
- negative_prompt_poolers=clip_l_pooler_n,
1230
- device=device,
1231
- dtype=torch.bfloat16,
1232
- image_embeddings=image_encoder_last_hidden_state,
1233
- latent_indices=latent_indices,
1234
- clean_latents=clean_latents,
1235
- clean_latent_indices=clean_latent_indices,
1236
- clean_latents_2x=clean_latents_2x,
1237
- clean_latent_2x_indices=clean_latent_2x_indices,
1238
- clean_latents_4x=clean_latents_4x,
1239
- clean_latent_4x_indices=clean_latent_4x_indices,
1240
- )
1241
-
1242
- # concatenate generated latents
1243
- total_generated_latent_frames += int(generated_latents.shape[2])
1244
- if not f1_mode:
1245
- # Inverted Anti-drifting: prepend generated latents to history latents
1246
- if is_last_section:
1247
- generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
1248
- total_generated_latent_frames += 1
1249
 
1250
- history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
1251
- real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
1252
- else:
1253
- # F1 mode: append generated latents to history latents
1254
- history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
1255
- real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
1256
-
1257
- logger.info(f"Generated. Latent shape {real_history_latents.shape}")
1258
-
1259
- # # TODO support saving intermediate video
1260
- # clean_memory_on_device(device)
1261
- # vae.to(device)
1262
- # if history_pixels is None:
1263
- # history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
1264
- # else:
1265
- # section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
1266
- # overlapped_frames = latent_window_size * 4 - 3
1267
- # current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
1268
- # history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
1269
- # vae.to("cpu")
1270
- # # if not is_last_section:
1271
- # # # save intermediate video
1272
- # # save_video(history_pixels[0], args, total_generated_latent_frames)
1273
- # print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
1274
 
1275
- if one_frame_inference is not None:
1276
- real_history_latents = real_history_latents[:, :, 1:, :, :] # remove the first frame (start_latent)
1277
 
1278
  # Only clean up shared models if they were created within this function
1279
  if shared_models is None:
@@ -1284,8 +1254,9 @@ def generate(
1284
  model.to("cpu")
1285
 
1286
  # wait for 5 seconds until block swap is done
1287
- logger.info("Waiting for 5 seconds to finish block swap")
1288
- time.sleep(5)
 
1289
 
1290
  gc.collect()
1291
  clean_memory_on_device(device)
@@ -1293,6 +1264,156 @@ def generate(
1293
  return vae, real_history_latents
1294
 
1295
 
1296
  def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
1297
  """Save latent to file
1298
 
 
114
  "--one_frame_inference",
115
  type=str,
116
  default=None,
117
+ help="one frame inference, default is None, comma separated values from 'no_2x', 'no_4x', 'no_post', 'control_indices' and 'target_index'.",
118
  )
119
  parser.add_argument(
120
+ "--control_image_path", type=str, default=None, nargs="*", help="path to control (reference) image for one frame inference."
 
 
 
121
  )
122
  parser.add_argument(
123
+ "--control_image_mask_path",
124
  type=str,
125
  default=None,
126
  nargs="*",
127
+ help="path to control (reference) image mask for one frame inference.",
128
  )
129
  parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
130
  parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
 
151
  default=None,
152
  help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
153
  )
154
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
155
  parser.add_argument(
156
  "--latent_paddings",
157
  type=str,
 
177
  parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
178
  parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
179
  # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
180
+ parser.add_argument(
181
+ "--rope_scaling_factor", type=float, default=0.5, help="RoPE scaling factor for high resolution (H/W), default is 0.5"
182
+ )
183
+ parser.add_argument(
184
+ "--rope_scaling_timestep_threshold",
185
+ type=int,
186
+ default=None,
187
+ help="RoPE scaling timestep threshold, default is None (disable), if set, RoPE scaling will be applied only for timesteps >= threshold, around 800 is good starting point",
188
+ )
189
+
190
  parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
191
  parser.add_argument(
192
  "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
 
255
 
256
  # Create dictionary of overrides
257
  overrides = {"prompt": prompt}
258
+ # Initialize control_image_path and control_image_mask_path as a list to accommodate multiple paths
259
+ overrides["control_image_path"] = []
260
+ overrides["control_image_mask_path"] = []
261
 
262
  for part in parts[1:]:
263
  if not part.strip():
 
283
  # overrides["flow_shift"] = float(value)
284
  elif option == "i":
285
  overrides["image_path"] = value
286
+ # elif option == "im":
287
+ # overrides["image_mask_path"] = value
288
  # elif option == "cn":
289
  # overrides["control_path"] = value
290
  elif option == "n":
 
292
  elif option == "vs": # video_sections
293
  overrides["video_sections"] = int(value)
294
  elif option == "ei": # end_image_path
295
+ overrides["end_image_path"] = value
296
+ elif option == "ci": # control_image_path
297
+ overrides["control_image_path"].append(value)
298
+ elif option == "cim": # control_image_mask_path
299
+ overrides["control_image_mask_path"].append(value)
300
  elif option == "of": # one_frame_inference
301
  overrides["one_frame_inference"] = value
302
 
303
+ # If no control_image_path was provided, remove the empty list
304
+ if not overrides["control_image_path"]:
305
+ del overrides["control_image_path"]
306
+ if not overrides["control_image_mask_path"]:
307
+ del overrides["control_image_mask_path"]
308
 
309
  return overrides
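For example, a prompt-file line that uses the new control-image options (assuming the usual " --option value" syntax of these prompt files; paths are placeholders) would be parsed by the function above roughly as follows:

line = "1girl, wearing the reference outfit --ci ref_character.png --ci ref_outfit.png --cim character_mask.png --of target_index=1"
overrides = parse_prompt_line(line)
# expected result:
# overrides["control_image_path"]      == ["ref_character.png", "ref_outfit.png"]  # repeated --ci values
# overrides["control_image_mask_path"] == ["character_mask.png"]                   # --cim values
# overrides["one_frame_inference"]     == "target_index=1"                         # --of value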
310
 
 
375
 
376
  # do not fp8 optimize because we will merge LoRA weights
377
  model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
378
+
379
+ # apply RoPE scaling factor
380
+ if args.rope_scaling_timestep_threshold is not None:
381
+ logger.info(
382
+ f"Applying RoPE scaling factor {args.rope_scaling_factor} for timesteps >= {args.rope_scaling_timestep_threshold}"
383
+ )
384
+ model.enable_rope_scaling(args.rope_scaling_timestep_threshold, args.rope_scaling_factor)
385
  return model
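The new --rope_scaling_factor / --rope_scaling_timestep_threshold options gate RoPE scaling on the denoising timestep: per the help text, the spatial (H/W) positions are scaled by the factor (default 0.5) only for timesteps at or above the threshold (around 800 is suggested). A hedged sketch of that gating logic, not the repository's actual enable_rope_scaling implementation:

from typing import Optional

def rope_hw_scale(timestep: float, threshold: Optional[int], factor: float = 0.5) -> float:
    # apply the scaling factor only during early, high-noise timesteps
    if threshold is None or timestep < threshold:
        return 1.0
    return factor

print(rope_hw_scale(900, 800))  # 0.5 -> scaling active
print(rope_hw_scale(300, 800))  # 1.0 -> normal RoPE positions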
386
 
387
 
 
574
 
575
  # prepare image
576
  def preprocess_image(image_path: str):
577
+ image = Image.open(image_path)
578
+ if image.mode == "RGBA":
579
+ alpha = image.split()[-1]
580
+ else:
581
+ alpha = None
582
+ image = image.convert("RGB")
583
 
584
  image_np = np.array(image) # PIL to numpy, HWC
585
 
586
  image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
587
  image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
588
  image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
589
+ return image_tensor, image_np, alpha
590
 
591
  section_image_paths = parse_section_strings(args.image_path)
592
 
593
  section_images = {}
594
  for index, image_path in section_image_paths.items():
595
+ img_tensor, img_np, _ = preprocess_image(image_path)
596
  section_images[index] = (img_tensor, img_np)
597
 
598
+ # check end image
599
+ if args.end_image_path is not None:
600
+ end_image_tensor, _, _ = preprocess_image(args.end_image_path)
601
+ else:
602
+ end_image_tensor = None
603
+
604
  # check end images
605
+ if args.control_image_path is not None and len(args.control_image_path) > 0:
606
+ control_image_tensors = []
607
+ control_mask_images = []
608
+ for ctrl_image_path in args.control_image_path:
609
+ control_image_tensor, _, control_mask = preprocess_image(ctrl_image_path)
610
+ control_image_tensors.append(control_image_tensor)
611
+ control_mask_images.append(control_mask)
612
  else:
613
+ control_image_tensors = None
614
+ control_mask_images = None
615
 
616
  # configure negative prompt
617
  n_prompt = args.negative_prompt if args.negative_prompt else ""
 
674
  image_encoder.to(device)
675
 
676
  # encode image with image encoder
677
+
678
  section_image_encoder_last_hidden_states = {}
679
  for index, (img_tensor, img_np) in section_images.items():
680
  with torch.no_grad():
 
697
  start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
698
  section_start_latents[index] = start_latent
699
 
700
+ end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
701
+
702
+ control_latents = None
703
+ if control_image_tensors is not None:
704
+ control_latents = []
705
+ for ctrl_image_tensor in control_image_tensors:
706
+ control_latent = hunyuan.vae_encode(ctrl_image_tensor, vae).cpu()
707
+ control_latents.append(control_latent)
708
 
709
  vae.to("cpu") # move VAE to CPU to save memory
710
  clean_memory_on_device(device)
 
741
  }
742
  arg_c_img[index] = arg_c_img_i
743
 
744
+ return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, control_latents, control_mask_images
745
 
746
 
747
  # def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
 
961
  if shared_models is not None:
962
  # Use shared models and encoded data
963
  vae = shared_models.get("vae")
964
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
965
+ prepare_i2v_inputs(args, device, vae, shared_models)
966
  )
967
  else:
968
  # prepare inputs without shared models
969
  vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
970
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
971
+ prepare_i2v_inputs(args, device, vae)
972
+ )
973
 
974
  if shared_models is None or "model" not in shared_models:
975
  # load DiT model
 
1019
  for mode in args.one_frame_inference.split(","):
1020
  one_frame_inference.add(mode.strip())
1021
 
1022
+ if one_frame_inference is not None:
1023
+ real_history_latents = generate_with_one_frame_inference(
1024
+ args,
1025
+ model,
1026
+ context,
1027
+ context_null,
1028
+ context_img,
1029
+ control_latents,
1030
+ control_mask_images,
1031
+ latent_window_size,
1032
+ height,
1033
+ width,
1034
+ device,
1035
+ seed_g,
1036
+ one_frame_inference,
1037
+ )
 
1038
  else:
1039
+ # prepare history latents
1040
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
1041
+ if end_latent is not None and not f1_mode:
1042
+ logger.info(f"Use end image(s): {args.end_image_path}")
1043
+ history_latents[:, :, :1] = end_latent.to(history_latents)
 
 
 
1044
 
1045
+ # prepare clean latents and indices
1046
  if not f1_mode:
1047
  # Inverted Anti-drifting
1048
+ total_generated_latent_frames = 0
1049
+ latent_paddings = reversed(range(total_latent_sections))
1050
+
1051
+ if total_latent_sections > 4 and one_frame_inference is None:
1052
+ # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
1053
+ # items looks better than expanding it when total_latent_sections > 4
1054
+ # One can try to remove below trick and just
1055
+ # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
1056
+ # 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
1057
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
1058
+
1059
+ if args.latent_paddings is not None:
1060
+ # parse user defined latent paddings
1061
+ user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
1062
+ if len(user_latent_paddings) < total_latent_sections:
1063
+ print(
1064
+ f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
1065
+ )
1066
+ print(f"Use default paddings instead for unspecified sections.")
1067
+ latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
1068
+ elif len(user_latent_paddings) > total_latent_sections:
1069
+ print(
1070
+ f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
1071
+ )
1072
+ print(f"Use only first {total_latent_sections} paddings instead.")
1073
+ latent_paddings = user_latent_paddings[:total_latent_sections]
1074
+ else:
1075
+ latent_paddings = user_latent_paddings
1076
  else:
1077
+ start_latent = context_img[0]["start_latent"]
1078
+ history_latents = torch.cat([history_latents, start_latent], dim=2)
1079
+ total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
1080
+ latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
1081
+
1082
+ latent_paddings = list(latent_paddings) # make sure it's a list
1083
+ for loop_index in range(total_latent_sections):
1084
+ latent_padding = latent_paddings[loop_index]
1085
+
1086
+ if not f1_mode:
1087
+ # Inverted Anti-drifting
1088
+ section_index_reverse = loop_index # 0, 1, 2, 3
1089
+ section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
1090
+ section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
1091
+
1092
+ is_last_section = section_index == 0
1093
+ is_first_section = section_index_reverse == 0
1094
+ latent_padding_size = latent_padding * latent_window_size
1095
+
1096
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
1097
+ else:
1098
+ section_index = loop_index # 0, 1, 2, 3
1099
+ section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
1100
+ is_last_section = loop_index == total_latent_sections - 1
1101
+ is_first_section = loop_index == 0
1102
+ latent_padding_size = 0 # dummy padding for F1 mode
1103
+
1104
+ # select start latent
1105
+ if section_index_from_last in context_img:
1106
+ image_index = section_index_from_last
1107
+ elif section_index in context_img:
1108
+ image_index = section_index
1109
+ else:
1110
+ image_index = 0
1111
 
1112
+ start_latent = context_img[image_index]["start_latent"]
1113
+ image_path = context_img[image_index]["image_path"]
1114
+ if image_index != 0: # use section image other than section 0
1115
+ logger.info(
1116
+ f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}"
1117
+ )
1118
 
1119
+ if not f1_mode:
1120
+ # Inverted Anti-drifting
1121
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
1122
+ (
1123
+ clean_latent_indices_pre,
1124
+ blank_indices,
1125
+ latent_indices,
1126
+ clean_latent_indices_post,
1127
+ clean_latent_2x_indices,
1128
+ clean_latent_4x_indices,
1129
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
1130
+
1131
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
1132
+
1133
+ clean_latents_pre = start_latent.to(history_latents)
1134
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
1135
+ [1, 2, 16], dim=2
1136
+ )
1137
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
1138
 
1139
+ else:
1140
+ # F1 mode
1141
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
1142
+ (
1143
+ clean_latent_indices_start,
1144
+ clean_latent_4x_indices,
1145
+ clean_latent_2x_indices,
1146
+ clean_latent_1x_indices,
1147
+ latent_indices,
1148
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
1149
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
1150
+
1151
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
1152
+ [16, 2, 1], dim=2
1153
+ )
1154
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
1155
+
1156
+ # if use_teacache:
1157
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
1158
+ # else:
1159
+ # transformer.initialize_teacache(enable_teacache=False)
1160
+
1161
+ # prepare conditioning inputs
1162
+ if section_index_from_last in context:
1163
+ prompt_index = section_index_from_last
1164
+ elif section_index in context:
1165
+ prompt_index = section_index
1166
+ else:
1167
+ prompt_index = 0
1168
 
1169
+ context_for_index = context[prompt_index]
1170
+ # if args.section_prompts is not None:
1171
+ logger.info(f"Section {section_index}: {context_for_index['prompt']}")
1172
 
1173
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
1174
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
1175
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
 
1176
 
1177
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
1178
+ device, dtype=torch.bfloat16
1179
  )
1180
 
1181
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
1182
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
1183
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1184
+
1185
+ generated_latents = sample_hunyuan(
1186
+ transformer=model,
1187
+ sampler=args.sample_solver,
1188
+ width=width,
1189
+ height=height,
1190
+ frames=num_frames,
1191
+ real_guidance_scale=args.guidance_scale,
1192
+ distilled_guidance_scale=args.embedded_cfg_scale,
1193
+ guidance_rescale=args.guidance_rescale,
1194
+ # shift=3.0,
1195
+ num_inference_steps=args.infer_steps,
1196
+ generator=seed_g,
1197
+ prompt_embeds=llama_vec,
1198
+ prompt_embeds_mask=llama_attention_mask,
1199
+ prompt_poolers=clip_l_pooler,
1200
+ negative_prompt_embeds=llama_vec_n,
1201
+ negative_prompt_embeds_mask=llama_attention_mask_n,
1202
+ negative_prompt_poolers=clip_l_pooler_n,
1203
+ device=device,
1204
+ dtype=torch.bfloat16,
1205
+ image_embeddings=image_encoder_last_hidden_state,
1206
+ latent_indices=latent_indices,
1207
+ clean_latents=clean_latents,
1208
+ clean_latent_indices=clean_latent_indices,
1209
+ clean_latents_2x=clean_latents_2x,
1210
+ clean_latent_2x_indices=clean_latent_2x_indices,
1211
+ clean_latents_4x=clean_latents_4x,
1212
+ clean_latent_4x_indices=clean_latent_4x_indices,
1213
+ )
 
 
 
 
1214
 
1215
+ # concatenate generated latents
1216
+ total_generated_latent_frames += int(generated_latents.shape[2])
1217
+ if not f1_mode:
1218
+ # Inverted Anti-drifting: prepend generated latents to history latents
1219
+ if is_last_section:
1220
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
1221
+ total_generated_latent_frames += 1
1222
 
1223
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
1224
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
1225
+ else:
1226
+ # F1 mode: append generated latents to history latents
1227
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
1228
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
1229
+
1230
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
1231
+
1232
+ # # TODO support saving intermediate video
1233
+ # clean_memory_on_device(device)
1234
+ # vae.to(device)
1235
+ # if history_pixels is None:
1236
+ # history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
1237
+ # else:
1238
+ # section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
1239
+ # overlapped_frames = latent_window_size * 4 - 3
1240
+ # current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
1241
+ # history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
1242
+ # vae.to("cpu")
1243
+ # # if not is_last_section:
1244
+ # # # save intermediate video
1245
+ # # save_video(history_pixels[0], args, total_generated_latent_frames)
1246
+ # print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
1247
 
1248
  # Only clean up shared models if they were created within this function
1249
  if shared_models is None:
 
1254
  model.to("cpu")
1255
 
1256
  # wait for 5 seconds until block swap is done
1257
+ if args.blocks_to_swap > 0:
1258
+ logger.info("Waiting for 5 seconds to finish block swap")
1259
+ time.sleep(5)
1260
 
1261
  gc.collect()
1262
  clean_memory_on_device(device)
 
1264
  return vae, real_history_latents
1265
 
1266
 
1267
+ def generate_with_one_frame_inference(
1268
+ args: argparse.Namespace,
1269
+ model: HunyuanVideoTransformer3DModelPacked,
1270
+ context: Dict[int, Dict[str, torch.Tensor]],
1271
+ context_null: Dict[str, torch.Tensor],
1272
+ context_img: Dict[int, Dict[str, torch.Tensor]],
1273
+ control_latents: Optional[List[torch.Tensor]],
1274
+ control_mask_images: Optional[List[Optional[Image.Image]]],
1275
+ latent_window_size: int,
1276
+ height: int,
1277
+ width: int,
1278
+ device: torch.device,
1279
+ seed_g: torch.Generator,
1280
+ one_frame_inference: set[str],
1281
+ ) -> torch.Tensor:
1282
+ # one frame inference
1283
+ sample_num_frames = 1
1284
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
1285
+ latent_indices[:, 0] = latent_window_size # last of latent_window
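+ # default index layout: control (clean 1x) latents at 0, the target frame at latent_window_size,
+ # an optional zero "post" latent at 1 + latent_window_size, then the 2x and 4x history slots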
1286
+
1287
+ def get_latent_mask(mask_image: Image.Image) -> torch.Tensor:
1288
+ if mask_image.mode != "L":
1289
+ mask_image = mask_image.convert("L")
1290
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
1291
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
1292
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
1293
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
1294
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
1295
+ mask_image = mask_image.to(torch.float32)
1296
+ return mask_image
1297
+
1298
+ if control_latents is None or len(control_latents) == 0:
1299
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
1300
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
1301
+
1302
+ if "no_post" not in one_frame_inference:
1303
+ # add zero latents as clean latents post
1304
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
1305
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
1306
+
1307
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
1308
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
1309
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
1310
+ if "no_post" not in one_frame_inference:
1311
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
1312
+
1313
+ for i in range(len(control_latents)):
1314
+ mask_image = None
1315
+ if args.control_image_mask_path is not None and i < len(args.control_image_mask_path):
1316
+ mask_image = get_latent_mask(Image.open(args.control_image_mask_path[i]))
1317
+ logger.info(
1318
+ f"Apply mask for clean latents 1x for {i + 1}: {args.control_image_mask_path[i]}, shape: {mask_image.shape}"
1319
+ )
1320
+ elif control_mask_images is not None and i < len(control_mask_images) and control_mask_images[i] is not None:
1321
+ mask_image = get_latent_mask(control_mask_images[i])
1322
+ logger.info(f"Apply mask for clean latents 1x for {i + 1} with alpha channel: {mask_image.shape}")
1323
+ if mask_image is not None:
1324
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * mask_image
1325
+
1326
+ for one_frame_param in one_frame_inference:
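+ # options are given as a comma-separated list, e.g. "target_index=1,control_index=0;10,no_2x,no_4x" (illustrative)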
1327
+ if one_frame_param.startswith("target_index="):
1328
+ target_index = int(one_frame_param.split("=")[1])
1329
+ latent_indices[:, 0] = target_index
1330
+ logger.info(f"Set index for target: {target_index}")
1331
+ elif one_frame_param.startswith("control_index="):
1332
+ control_indices = one_frame_param.split("=")[1].split(";")
1333
+ i = 0
1334
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
1335
+ control_index = int(control_indices[i])
1336
+ clean_latent_indices[:, i] = control_index
1337
+ i += 1
1338
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
1339
+
1340
+ # "default" option does nothing, so we can skip it
1341
+ if "default" in one_frame_inference:
1342
+ pass
1343
+
1344
+ if "no_2x" in one_frame_inference:
1345
+ clean_latents_2x = None
1346
+ clean_latent_2x_indices = None
1347
+ logger.info(f"No clean_latents_2x")
1348
+ else:
1349
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
1350
+ index = 1 + latent_window_size + 1
1351
+ clean_latent_2x_indices = torch.arange(index, index + 2).unsqueeze(0) # 2
1352
+
1353
+ if "no_4x" in one_frame_inference:
1354
+ clean_latents_4x = None
1355
+ clean_latent_4x_indices = None
1356
+ logger.info(f"No clean_latents_4x")
1357
+ else:
1358
+ clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
1359
+ index = 1 + latent_window_size + 1 + 2
1360
+ clean_latent_4x_indices = torch.arange(index, index + 16).unsqueeze(0) # 16
1361
+
1362
+ logger.info(
1363
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
1364
+ )
1365
+
1366
+ # prepare conditioning inputs
1367
+ prompt_index = 0
1368
+ image_index = 0
1369
+
1370
+ context_for_index = context[prompt_index]
1371
+ logger.info(f"Prompt: {context_for_index['prompt']}")
1372
+
1373
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
1374
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
1375
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1376
+
1377
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=torch.bfloat16)
1378
+
1379
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
1380
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
1381
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1382
+
1383
+ generated_latents = sample_hunyuan(
1384
+ transformer=model,
1385
+ sampler=args.sample_solver,
1386
+ width=width,
1387
+ height=height,
1388
+ frames=1,
1389
+ real_guidance_scale=args.guidance_scale,
1390
+ distilled_guidance_scale=args.embedded_cfg_scale,
1391
+ guidance_rescale=args.guidance_rescale,
1392
+ # shift=3.0,
1393
+ num_inference_steps=args.infer_steps,
1394
+ generator=seed_g,
1395
+ prompt_embeds=llama_vec,
1396
+ prompt_embeds_mask=llama_attention_mask,
1397
+ prompt_poolers=clip_l_pooler,
1398
+ negative_prompt_embeds=llama_vec_n,
1399
+ negative_prompt_embeds_mask=llama_attention_mask_n,
1400
+ negative_prompt_poolers=clip_l_pooler_n,
1401
+ device=device,
1402
+ dtype=torch.bfloat16,
1403
+ image_embeddings=image_encoder_last_hidden_state,
1404
+ latent_indices=latent_indices,
1405
+ clean_latents=clean_latents,
1406
+ clean_latent_indices=clean_latent_indices,
1407
+ clean_latents_2x=clean_latents_2x,
1408
+ clean_latent_2x_indices=clean_latent_2x_indices,
1409
+ clean_latents_4x=clean_latents_4x,
1410
+ clean_latent_4x_indices=clean_latent_4x_indices,
1411
+ )
1412
+
1413
+ real_history_latents = generated_latents.to(clean_latents)
1414
+ return real_history_latents
1415
+
1416
+
1417
  def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
1418
  """Save latent to file
1419
 
fpack_train_network.py ADDED
@@ -0,0 +1,617 @@
1
+ import argparse
2
+ import gc
3
+ import math
4
+ import time
5
+ from typing import Optional
6
+ from PIL import Image
7
+
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchvision.transforms.functional as TF
12
+ from tqdm import tqdm
13
+ from accelerate import Accelerator, init_empty_weights
14
+
15
+ from dataset import image_video_dataset
16
+ from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ARCHITECTURE_FRAMEPACK_FULL, load_video
17
+ from fpack_generate_video import decode_latent
18
+ from frame_pack import hunyuan
19
+ from frame_pack.clip_vision import hf_clip_vision_encode
20
+ from frame_pack.framepack_utils import load_image_encoders, load_text_encoder1, load_text_encoder2
21
+ from frame_pack.framepack_utils import load_vae as load_framepack_vae
22
+ from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
23
+ from frame_pack.k_diffusion_hunyuan import sample_hunyuan
24
+ from frame_pack.utils import crop_or_pad_yield_mask
25
+ from dataset.image_video_dataset import resize_image_to_bucket
26
+ from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
27
+
28
+ import logging
29
+
30
+ logger = logging.getLogger(__name__)
31
+ logging.basicConfig(level=logging.INFO)
32
+
33
+ from utils import model_utils
34
+ from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
35
+
36
+
37
+ class FramePackNetworkTrainer(NetworkTrainer):
38
+ def __init__(self):
39
+ super().__init__()
40
+
41
+ # region model specific
42
+
43
+ @property
44
+ def architecture(self) -> str:
45
+ return ARCHITECTURE_FRAMEPACK
46
+
47
+ @property
48
+ def architecture_full_name(self) -> str:
49
+ return ARCHITECTURE_FRAMEPACK_FULL
50
+
51
+ def handle_model_specific_args(self, args):
52
+ self._i2v_training = True
53
+ self._control_training = False
54
+ self.default_guidance_scale = 10.0 # embedded guidance scale
55
+
56
+ def process_sample_prompts(
57
+ self,
58
+ args: argparse.Namespace,
59
+ accelerator: Accelerator,
60
+ sample_prompts: str,
61
+ ):
62
+ device = accelerator.device
63
+
64
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
65
+ prompts = load_prompts(sample_prompts)
66
+
67
+ # load text encoder
68
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
69
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
70
+ text_encoder2.to(device)
71
+
72
+ sample_prompts_te_outputs = {} # (prompt) -> (t1 embeds, t1 mask, t2 embeds)
73
+ for prompt_dict in prompts:
74
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
75
+ if p is None or p in sample_prompts_te_outputs:
76
+ continue
77
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
78
+ with torch.amp.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
79
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(p, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
80
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
81
+
82
+ llama_vec = llama_vec.to("cpu")
83
+ llama_attention_mask = llama_attention_mask.to("cpu")
84
+ clip_l_pooler = clip_l_pooler.to("cpu")
85
+ sample_prompts_te_outputs[p] = (llama_vec, llama_attention_mask, clip_l_pooler)
86
+ del text_encoder1, text_encoder2
87
+ clean_memory_on_device(device)
88
+
89
+ # image embedding for I2V training
90
+ feature_extractor, image_encoder = load_image_encoders(args)
91
+ image_encoder.to(device)
92
+
93
+ # encode image with image encoder
94
+ sample_prompts_image_embs = {}
95
+ for prompt_dict in prompts:
96
+ image_path = prompt_dict.get("image_path", None)
97
+ assert image_path is not None, "image_path should be set for I2V training"
98
+ if image_path in sample_prompts_image_embs:
99
+ continue
100
+
101
+ logger.info(f"Encoding image to image encoder context: {image_path}")
102
+
103
+ height = prompt_dict.get("height", 256)
104
+ width = prompt_dict.get("width", 256)
105
+
106
+ img = Image.open(image_path).convert("RGB")
107
+ img_np = np.array(img) # PIL to numpy, HWC
108
+ img_np = image_video_dataset.resize_image_to_bucket(img_np, (width, height)) # returns a numpy array
109
+
110
+ with torch.no_grad():
111
+ image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
112
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
113
+
114
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to("cpu")
115
+ sample_prompts_image_embs[image_path] = image_encoder_last_hidden_state
116
+
117
+ del image_encoder
118
+ clean_memory_on_device(device)
119
+
120
+ # prepare sample parameters
121
+ sample_parameters = []
122
+ for prompt_dict in prompts:
123
+ prompt_dict_copy = prompt_dict.copy()
124
+
125
+ p = prompt_dict.get("prompt", "")
126
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
127
+ prompt_dict_copy["llama_vec"] = llama_vec
128
+ prompt_dict_copy["llama_attention_mask"] = llama_attention_mask
129
+ prompt_dict_copy["clip_l_pooler"] = clip_l_pooler
130
+
131
+ p = prompt_dict.get("negative_prompt", "")
132
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
133
+ prompt_dict_copy["negative_llama_vec"] = llama_vec
134
+ prompt_dict_copy["negative_llama_attention_mask"] = llama_attention_mask
135
+ prompt_dict_copy["negative_clip_l_pooler"] = clip_l_pooler
136
+
137
+ p = prompt_dict.get("image_path", None)
138
+ prompt_dict_copy["image_encoder_last_hidden_state"] = sample_prompts_image_embs[p]
139
+
140
+ sample_parameters.append(prompt_dict_copy)
141
+
142
+ clean_memory_on_device(accelerator.device)
143
+ return sample_parameters
144
+
145
+ def do_inference(
146
+ self,
147
+ accelerator,
148
+ args,
149
+ sample_parameter,
150
+ vae,
151
+ dit_dtype,
152
+ transformer,
153
+ discrete_flow_shift,
154
+ sample_steps,
155
+ width,
156
+ height,
157
+ frame_count,
158
+ generator,
159
+ do_classifier_free_guidance,
160
+ guidance_scale,
161
+ cfg_scale,
162
+ image_path=None,
163
+ control_video_path=None,
164
+ ):
165
+ """architecture dependent inference"""
166
+ model: HunyuanVideoTransformer3DModelPacked = transformer
167
+ device = accelerator.device
168
+ if cfg_scale is None:
169
+ cfg_scale = 1.0
170
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
171
+
172
+ # prepare parameters
173
+ one_frame_mode = args.one_frame
174
+ if one_frame_mode:
175
+ one_frame_inference = set()
176
+ for mode in sample_parameter["one_frame"].split(","):
177
+ one_frame_inference.add(mode.strip())
178
+ else:
179
+ one_frame_inference = None
180
+
181
+ latent_window_size = args.latent_window_size # default is 9
182
+ latent_f = (frame_count - 1) // 4 + 1
183
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
184
+ if total_latent_sections < 1 and not one_frame_mode:
185
+ logger.warning(f"Not enough frames for FramePack: {latent_f}, minimum: {latent_window_size*4+1}")
186
+ return None
187
+
188
+ latent_f = total_latent_sections * latent_window_size + 1
189
+ actual_frame_count = (latent_f - 1) * 4 + 1
190
+ if actual_frame_count != frame_count:
191
+ logger.info(f"Frame count mismatch: {actual_frame_count} != {frame_count}, trimming to {actual_frame_count}")
192
+ frame_count = actual_frame_count
193
+ num_frames = latent_window_size * 4 - 3
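+ # the FramePack VAE compresses time by 4 (with one extra leading frame), so a section of
+ # latent_window_size latents corresponds to latent_window_size * 4 - 3 newly generated frames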
194
+
195
+ # prepare start and control latent
196
+ def encode_image(path):
197
+ image = Image.open(path)
198
+ if image.mode == "RGBA":
199
+ alpha = image.split()[-1]
200
+ image = image.convert("RGB")
201
+ else:
202
+ alpha = None
203
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
204
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).unsqueeze(0).float() # 1, C, 1, H, W
205
+ image = image / 127.5 - 1 # -1 to 1
206
+ return hunyuan.vae_encode(image, vae).to("cpu"), alpha
207
+
208
+ # VAE encoding
209
+ logger.info(f"Encoding image to latent space")
210
+ vae.to(device)
211
+
212
+ start_latent, _ = (
213
+ encode_image(image_path) if image_path else (torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32), None)  # keep the (latent, alpha) tuple shape
214
+ )
215
+
216
+ if one_frame_mode:
217
+ control_latents = []
218
+ control_alphas = []
219
+ if "control_image_path" in sample_parameter:
220
+ for control_image_path in sample_parameter["control_image_path"]:
221
+ control_latent, control_alpha = encode_image(control_image_path)
222
+ control_latents.append(control_latent)
223
+ control_alphas.append(control_alpha)
224
+ else:
225
+ control_latents = None
226
+ control_alphas = None
227
+
228
+ vae.to("cpu") # move VAE to CPU to save memory
229
+ clean_memory_on_device(device)
230
+
231
+ # sampling
232
+ if not one_frame_mode:
233
+ f1_mode = args.f1
234
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
235
+
236
+ if not f1_mode:
237
+ total_generated_latent_frames = 0
238
+ latent_paddings = reversed(range(total_latent_sections))
239
+ else:
240
+ total_generated_latent_frames = 1
241
+ history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
242
+ latent_paddings = [0] * total_latent_sections
243
+
244
+ if total_latent_sections > 4:
245
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
246
+
247
+ latent_paddings = list(latent_paddings)
248
+ for loop_index in range(total_latent_sections):
249
+ latent_padding = latent_paddings[loop_index]
250
+
251
+ if not f1_mode:
252
+ is_last_section = latent_padding == 0
253
+ latent_padding_size = latent_padding * latent_window_size
254
+
255
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
256
+
257
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
258
+ (
259
+ clean_latent_indices_pre,
260
+ blank_indices,
261
+ latent_indices,
262
+ clean_latent_indices_post,
263
+ clean_latent_2x_indices,
264
+ clean_latent_4x_indices,
265
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
266
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
267
+
268
+ clean_latents_pre = start_latent.to(history_latents)
269
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
270
+ [1, 2, 16], dim=2
271
+ )
272
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
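+ # 1x context = [start-image latent, most recently generated latent] (history is stored newest-first here)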
273
+ else:
274
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
275
+ (
276
+ clean_latent_indices_start,
277
+ clean_latent_4x_indices,
278
+ clean_latent_2x_indices,
279
+ clean_latent_1x_indices,
280
+ latent_indices,
281
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
282
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
283
+
284
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
285
+ [16, 2, 1], dim=2
286
+ )
287
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
288
+
289
+ # if use_teacache:
290
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
291
+ # else:
292
+ # transformer.initialize_teacache(enable_teacache=False)
293
+
294
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
295
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
296
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
297
+ if cfg_scale == 1.0:
298
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
299
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
300
+ else:
301
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
302
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
303
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
304
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
305
+ device, dtype=torch.bfloat16
306
+ )
307
+
308
+ generated_latents = sample_hunyuan(
309
+ transformer=model,
310
+ sampler=args.sample_solver,
311
+ width=width,
312
+ height=height,
313
+ frames=num_frames,
314
+ real_guidance_scale=cfg_scale,
315
+ distilled_guidance_scale=guidance_scale,
316
+ guidance_rescale=0.0,
317
+ # shift=3.0,
318
+ num_inference_steps=sample_steps,
319
+ generator=generator,
320
+ prompt_embeds=llama_vec,
321
+ prompt_embeds_mask=llama_attention_mask,
322
+ prompt_poolers=clip_l_pooler,
323
+ negative_prompt_embeds=llama_vec_n,
324
+ negative_prompt_embeds_mask=llama_attention_mask_n,
325
+ negative_prompt_poolers=clip_l_pooler_n,
326
+ device=device,
327
+ dtype=torch.bfloat16,
328
+ image_embeddings=image_encoder_last_hidden_state,
329
+ latent_indices=latent_indices,
330
+ clean_latents=clean_latents,
331
+ clean_latent_indices=clean_latent_indices,
332
+ clean_latents_2x=clean_latents_2x,
333
+ clean_latent_2x_indices=clean_latent_2x_indices,
334
+ clean_latents_4x=clean_latents_4x,
335
+ clean_latent_4x_indices=clean_latent_4x_indices,
336
+ )
337
+
338
+ total_generated_latent_frames += int(generated_latents.shape[2])
339
+ if not f1_mode:
340
+ if is_last_section:
341
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
342
+ total_generated_latent_frames += 1
343
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
344
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
345
+ else:
346
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
347
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
348
+
349
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
350
+ else:
351
+ # one frame mode
352
+ sample_num_frames = 1
353
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
354
+ latent_indices[:, 0] = latent_window_size # last of latent_window
355
+
356
+ def get_latent_mask(mask_image: Image.Image):
357
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
358
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
359
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
360
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
361
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (B, C, F, H, W)
362
+ mask_image = mask_image.to(torch.float32)
363
+ return mask_image
364
+
365
+ if control_latents is None or len(control_latents) == 0:
366
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
367
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
+ control_alphas = [None]  # keep alphas aligned with the zero latent so the mask loop below does not fail
368
+
369
+ if "no_post" not in one_frame_inference:
370
+ # add zero latents as clean latents post
371
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
372
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
373
+
374
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
375
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
376
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
377
+ if "no_post" not in one_frame_inference:
378
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
379
+
380
+ # apply mask for control latents (clean latents)
381
+ for i in range(len(control_alphas)):
382
+ control_alpha = control_alphas[i]
383
+ if control_alpha is not None:
384
+ latent_mask = get_latent_mask(control_alpha)
385
+ logger.info(
386
+ f"Apply mask for clean latents 1x for {i+1}: shape: {latent_mask.shape}"
387
+ )
388
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * latent_mask
389
+
390
+ for one_frame_param in one_frame_inference:
391
+ if one_frame_param.startswith("target_index="):
392
+ target_index = int(one_frame_param.split("=")[1])
393
+ latent_indices[:, 0] = target_index
394
+ logger.info(f"Set index for target: {target_index}")
395
+ elif one_frame_param.startswith("control_index="):
396
+ control_indices = one_frame_param.split("=")[1].split(";")
397
+ i = 0
398
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
399
+ control_index = int(control_indices[i])
400
+ clean_latent_indices[:, i] = control_index
401
+ i += 1
402
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
403
+
404
+ if "no_2x" in one_frame_inference:
405
+ clean_latents_2x = None
406
+ clean_latent_2x_indices = None
407
+ logger.info(f"No clean_latents_2x")
408
+ else:
409
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
410
+ index = 1 + latent_window_size + 1
411
+ clean_latent_2x_indices = torch.arange(index, index + 2).unsqueeze(0)  # (1, 2), as in generate_with_one_frame_inference
412
+
413
+ if "no_4x" in one_frame_inference:
414
+ clean_latents_4x = None
415
+ clean_latent_4x_indices = None
416
+ logger.info(f"No clean_latents_4x")
417
+ else:
418
+ clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
+ index = 1 + latent_window_size + 1 + 2
419
+ clean_latent_4x_indices = torch.arange(index, index + 16).unsqueeze(0)  # (1, 16)
420
+
421
+ logger.info(
422
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
423
+ )
424
+
425
+ # prepare conditioning inputs
426
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
427
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
428
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
429
+ if cfg_scale == 1.0:
430
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
431
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
432
+ else:
433
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
434
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
435
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
436
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
437
+ device, dtype=torch.bfloat16
438
+ )
439
+
440
+ generated_latents = sample_hunyuan(
441
+ transformer=model,
442
+ sampler=args.sample_solver,
443
+ width=width,
444
+ height=height,
445
+ frames=1,
446
+ real_guidance_scale=cfg_scale,
447
+ distilled_guidance_scale=guidance_scale,
448
+ guidance_rescale=0.0,
449
+ # shift=3.0,
450
+ num_inference_steps=sample_steps,
451
+ generator=generator,
452
+ prompt_embeds=llama_vec,
453
+ prompt_embeds_mask=llama_attention_mask,
454
+ prompt_poolers=clip_l_pooler,
455
+ negative_prompt_embeds=llama_vec_n,
456
+ negative_prompt_embeds_mask=llama_attention_mask_n,
457
+ negative_prompt_poolers=clip_l_pooler_n,
458
+ device=device,
459
+ dtype=torch.bfloat16,
460
+ image_embeddings=image_encoder_last_hidden_state,
461
+ latent_indices=latent_indices,
462
+ clean_latents=clean_latents,
463
+ clean_latent_indices=clean_latent_indices,
464
+ clean_latents_2x=clean_latents_2x,
465
+ clean_latent_2x_indices=clean_latent_2x_indices,
466
+ clean_latents_4x=clean_latents_4x,
467
+ clean_latent_4x_indices=clean_latent_4x_indices,
468
+ )
469
+
470
+ real_history_latents = generated_latents.to(clean_latents)
471
+
472
+ # wait for 5 seconds until block swap is done
473
+ logger.info("Waiting for 5 seconds to finish block swap")
474
+ time.sleep(5)
475
+
476
+ gc.collect()
477
+ clean_memory_on_device(device)
478
+
479
+ video = decode_latent(
480
+ latent_window_size, total_latent_sections, args.bulk_decode, vae, real_history_latents, device, one_frame_mode
481
+ )
482
+ video = video.to("cpu", dtype=torch.float32).unsqueeze(0) # add batch dimension
483
+ video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
484
+ clean_memory_on_device(device)
485
+
486
+ return video
487
+
488
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
489
+ vae_path = args.vae
490
+ logger.info(f"Loading VAE model from {vae_path}")
491
+ vae = load_framepack_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
492
+ return vae
493
+
494
+ def load_transformer(
495
+ self,
496
+ accelerator: Accelerator,
497
+ args: argparse.Namespace,
498
+ dit_path: str,
499
+ attn_mode: str,
500
+ split_attn: bool,
501
+ loading_device: str,
502
+ dit_weight_dtype: Optional[torch.dtype],
503
+ ):
504
+ logger.info(f"Loading DiT model from {dit_path}")
505
+ device = accelerator.device
506
+ model = load_packed_model(device, dit_path, attn_mode, loading_device, args.fp8_scaled, split_attn)
507
+ return model
508
+
509
+ def scale_shift_latents(self, latents):
510
+ # FramePack VAE includes scaling
511
+ return latents
512
+
513
+ def call_dit(
514
+ self,
515
+ args: argparse.Namespace,
516
+ accelerator: Accelerator,
517
+ transformer,
518
+ latents: torch.Tensor,
519
+ batch: dict[str, torch.Tensor],
520
+ noise: torch.Tensor,
521
+ noisy_model_input: torch.Tensor,
522
+ timesteps: torch.Tensor,
523
+ network_dtype: torch.dtype,
524
+ ):
525
+ model: HunyuanVideoTransformer3DModelPacked = transformer
526
+ device = accelerator.device
527
+ batch_size = latents.shape[0]
528
+
529
+ # maybe model.dtype is better than network_dtype...
530
+ distilled_guidance = torch.tensor([args.guidance_scale * 1000.0] * batch_size).to(device=device, dtype=network_dtype)
531
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
532
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
533
+ # for k, v in batch.items():
534
+ # if isinstance(v, torch.Tensor):
535
+ # print(f"{k}: {v.shape} {v.dtype} {v.device}")
536
+ with accelerator.autocast():
537
+ clean_latent_2x_indices = batch["clean_latent_2x_indices"] if "clean_latent_2x_indices" in batch else None
538
+ if clean_latent_2x_indices is not None:
539
+ clean_latent_2x = batch["latents_clean_2x"] if "latents_clean_2x" in batch else None
540
+ if clean_latent_2x is None:
541
+ clean_latent_2x = torch.zeros(
542
+ (batch_size, 16, 2, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
543
+ )
544
+ else:
545
+ clean_latent_2x = None
546
+
547
+ clean_latent_4x_indices = batch["clean_latent_4x_indices"] if "clean_latent_4x_indices" in batch else None
548
+ if clean_latent_4x_indices is not None:
549
+ clean_latent_4x = batch["latents_clean_4x"] if "latents_clean_4x" in batch else None
550
+ if clean_latent_4x is None:
551
+ clean_latent_4x = torch.zeros(
552
+ (batch_size, 16, 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
553
+ )
554
+ else:
555
+ clean_latent_4x = None
556
+
557
+ model_pred = model(
558
+ hidden_states=noisy_model_input,
559
+ timestep=timesteps,
560
+ encoder_hidden_states=batch["llama_vec"],
561
+ encoder_attention_mask=batch["llama_attention_mask"],
562
+ pooled_projections=batch["clip_l_pooler"],
563
+ guidance=distilled_guidance,
564
+ latent_indices=batch["latent_indices"],
565
+ clean_latents=batch["latents_clean"],
566
+ clean_latent_indices=batch["clean_latent_indices"],
567
+ clean_latents_2x=clean_latent_2x,
568
+ clean_latent_2x_indices=clean_latent_2x_indices,
569
+ clean_latents_4x=clean_latent_4x,
570
+ clean_latent_4x_indices=clean_latent_4x_indices,
571
+ image_embeddings=batch["image_embeddings"],
572
+ return_dict=False,
573
+ )
574
+ model_pred = model_pred[0] # returns tuple (model_pred, )
575
+
576
+ # flow matching loss
577
+ target = noise - latents
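+ # flow-matching objective: the model predicts the velocity from data to noise,
+ # so the regression target is noise - latents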
578
+
579
+ return model_pred, target
580
+
581
+ # endregion model specific
582
+
583
+
584
+ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
585
+ """FramePack specific parser setup"""
586
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
587
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
588
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
589
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
590
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
591
+ parser.add_argument(
592
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
593
+ )
594
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
595
+ parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
596
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once in sample generation")
597
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method for sample generation")
598
+ parser.add_argument("--one_frame", action="store_true", help="Use one frame sampling method for sample generation")
599
+ return parser
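+ # illustrative invocation (FramePack-specific flags only; the dataset, network and optimizer
+ # options defined in setup_parser_common are omitted here):
+ #   accelerate launch fpack_train_network.py --image_encoder <siglip_dir> \
+ #     --text_encoder1 <te1_path> --text_encoder2 <te2_path> --latent_window_size 9 --fp8_scaled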
600
+
601
+
602
+ if __name__ == "__main__":
603
+ parser = setup_parser_common()
604
+ parser = framepack_setup_parser(parser)
605
+
606
+ args = parser.parse_args()
607
+ args = read_config_from_file(args, parser)
608
+
609
+ assert (
610
+ args.vae_dtype is None or args.vae_dtype == "float16"
611
+ ), "VAE dtype must be float16 / VAEのdtypeはfloat16でなければなりません"
612
+ args.vae_dtype = "float16" # fixed
613
+ args.dit_dtype = "bfloat16" # fixed
614
+ args.sample_solver = "unipc" # for sample generation, fixed to unipc
615
+
616
+ trainer = FramePackNetworkTrainer()
617
+ trainer.train(args)
hv_train.py ADDED
@@ -0,0 +1,1721 @@
1
+ import ast
2
+ import asyncio
3
+ from datetime import timedelta
4
+ import gc
5
+ import importlib
6
+ import argparse
7
+ import math
8
+ import os
9
+ import pathlib
10
+ import re
11
+ import sys
12
+ import random
13
+ import time
14
+ import json
15
+ from multiprocessing import Value
16
+ from typing import Any, Dict, List, Optional
17
+ import accelerate
18
+ import numpy as np
19
+ from packaging.version import Version
20
+
21
+ import huggingface_hub
22
+ import toml
23
+
24
+ import torch
25
+ from tqdm import tqdm
26
+ from accelerate.utils import set_seed
27
+ from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
28
+ from safetensors.torch import load_file, save_file
29
+ import transformers
30
+ from diffusers.optimization import (
31
+ SchedulerType as DiffusersSchedulerType,
32
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
33
+ )
34
+ from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+
36
+ from dataset import config_utils
37
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
38
+ import hunyuan_model.text_encoder as text_encoder_module
39
+ from hunyuan_model.vae import load_vae
40
+ import hunyuan_model.vae as vae_module
41
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
42
+ import networks.lora as lora_module
43
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
44
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO
45
+
46
+ import logging
47
+
48
+ from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
49
+
50
+ logger = logging.getLogger(__name__)
51
+ logging.basicConfig(level=logging.INFO)
52
+
53
+
54
+ BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
55
+
56
+ # TODO make separate file for some functions to commonize with other scripts
57
+
58
+
59
+ def clean_memory_on_device(device: torch.device):
60
+ r"""
61
+ Clean memory on the specified device, will be called from training scripts.
62
+ """
63
+ gc.collect()
64
+
65
+ # device may "cuda" or "cuda:0", so we need to check the type of device
66
+ if device.type == "cuda":
67
+ torch.cuda.empty_cache()
68
+ if device.type == "xpu":
69
+ torch.xpu.empty_cache()
70
+ if device.type == "mps":
71
+ torch.mps.empty_cache()
72
+
73
+
74
+ # for collate_fn: epoch and step is multiprocessing.Value
75
+ class collator_class:
76
+ def __init__(self, epoch, step, dataset):
77
+ self.current_epoch = epoch
78
+ self.current_step = step
79
+ self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
80
+
81
+ def __call__(self, examples):
82
+ worker_info = torch.utils.data.get_worker_info()
83
+ # worker_info is None in the main process
84
+ if worker_info is not None:
85
+ dataset = worker_info.dataset
86
+ else:
87
+ dataset = self.dataset
88
+
89
+ # set epoch and step
90
+ dataset.set_current_epoch(self.current_epoch.value)
91
+ dataset.set_current_step(self.current_step.value)
92
+ return examples[0]
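+ # the dataset is expected to yield pre-batched examples, so the collator simply unwraps the single element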
93
+
94
+
95
+ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
96
+ """
97
+ DeepSpeed is not supported in this script currently.
98
+ """
99
+ if args.logging_dir is None:
100
+ logging_dir = None
101
+ else:
102
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
103
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
104
+
105
+ if args.log_with is None:
106
+ if logging_dir is not None:
107
+ log_with = "tensorboard"
108
+ else:
109
+ log_with = None
110
+ else:
111
+ log_with = args.log_with
112
+ if log_with in ["tensorboard", "all"]:
113
+ if logging_dir is None:
114
+ raise ValueError(
115
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
116
+ )
117
+ if log_with in ["wandb", "all"]:
118
+ try:
119
+ import wandb
120
+ except ImportError:
121
+ raise ImportError("No wandb / wandb がインストールされていないようです")
122
+ if logging_dir is not None:
123
+ os.makedirs(logging_dir, exist_ok=True)
124
+ os.environ["WANDB_DIR"] = logging_dir
125
+ if args.wandb_api_key is not None:
126
+ wandb.login(key=args.wandb_api_key)
127
+
128
+ kwargs_handlers = [
129
+ (
130
+ InitProcessGroupKwargs(
131
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
132
+ init_method=(
133
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
134
+ ),
135
+ timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
136
+ )
137
+ if torch.cuda.device_count() > 1
138
+ else None
139
+ ),
140
+ (
141
+ DistributedDataParallelKwargs(
142
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
143
+ )
144
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
145
+ else None
146
+ ),
147
+ ]
148
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
149
+
150
+ accelerator = Accelerator(
151
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
152
+ mixed_precision=args.mixed_precision,
153
+ log_with=log_with,
154
+ project_dir=logging_dir,
155
+ kwargs_handlers=kwargs_handlers,
156
+ )
157
+ print("accelerator device:", accelerator.device)
158
+ return accelerator
159
+
160
+
161
+ def line_to_prompt_dict(line: str) -> dict:
162
+ # subset of gen_img_diffusers
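+ # example line: "a cat walking on the grass --w 512 --h 320 --f 25 --d 42 --s 20"
+ # (--w width, --h height, --f frame count, --d seed, --s steps)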
163
+ prompt_args = line.split(" --")
164
+ prompt_dict = {}
165
+ prompt_dict["prompt"] = prompt_args[0]
166
+
167
+ for parg in prompt_args:
168
+ try:
169
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
170
+ if m:
171
+ prompt_dict["width"] = int(m.group(1))
172
+ continue
173
+
174
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
175
+ if m:
176
+ prompt_dict["height"] = int(m.group(1))
177
+ continue
178
+
179
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
180
+ if m:
181
+ prompt_dict["frame_count"] = int(m.group(1))
182
+ continue
183
+
184
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
185
+ if m:
186
+ prompt_dict["seed"] = int(m.group(1))
187
+ continue
188
+
189
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
190
+ if m: # steps
191
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
192
+ continue
193
+
194
+ # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
195
+ # if m: # scale
196
+ # prompt_dict["scale"] = float(m.group(1))
197
+ # continue
198
+ # m = re.match(r"n (.+)", parg, re.IGNORECASE)
199
+ # if m: # negative prompt
200
+ # prompt_dict["negative_prompt"] = m.group(1)
201
+ # continue
202
+
203
+ except ValueError as ex:
204
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
205
+ logger.error(ex)
206
+
207
+ return prompt_dict
208
+
209
+
210
+ def load_prompts(prompt_file: str) -> list[Dict]:
211
+ # read prompts
212
+ if prompt_file.endswith(".txt"):
213
+ with open(prompt_file, "r", encoding="utf-8") as f:
214
+ lines = f.readlines()
215
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
216
+ elif prompt_file.endswith(".toml"):
217
+ with open(prompt_file, "r", encoding="utf-8") as f:
218
+ data = toml.load(f)
219
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
220
+ elif prompt_file.endswith(".json"):
221
+ with open(prompt_file, "r", encoding="utf-8") as f:
222
+ prompts = json.load(f)
223
+
224
+ # preprocess prompts
225
+ for i in range(len(prompts)):
226
+ prompt_dict = prompts[i]
227
+ if isinstance(prompt_dict, str):
228
+ prompt_dict = line_to_prompt_dict(prompt_dict)
229
+ prompts[i] = prompt_dict
230
+ assert isinstance(prompt_dict, dict)
231
+
232
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
233
+ prompt_dict["enum"] = i
234
+ prompt_dict.pop("subset", None)
235
+
236
+ return prompts
237
+
238
+
239
+ def compute_density_for_timestep_sampling(
240
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
241
+ ):
242
+ """Compute the density for sampling the timesteps when doing SD3 training.
243
+
244
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
245
+
246
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
247
+ """
248
+ if weighting_scheme == "logit_normal":
249
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
250
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
251
+ u = torch.nn.functional.sigmoid(u)
252
+ elif weighting_scheme == "mode":
253
+ u = torch.rand(size=(batch_size,), device="cpu")
254
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
255
+ else:
256
+ u = torch.rand(size=(batch_size,), device="cpu")
257
+ return u
258
+
259
+
260
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
261
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
262
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
263
+ timesteps = timesteps.to(device)
264
+
265
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
266
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
267
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
268
+ # round to nearest timestep
269
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
270
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
271
+ else:
272
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
273
+
274
+ sigma = sigmas[step_indices].flatten()
275
+ while len(sigma.shape) < n_dim:
276
+ sigma = sigma.unsqueeze(-1)
277
+ return sigma
278
+
279
+
280
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
281
+ """Computes loss weighting scheme for SD3 training.
282
+
283
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
284
+
285
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
286
+ """
287
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
288
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
289
+ if weighting_scheme == "sigma_sqrt":
290
+ weighting = (sigmas**-2.0).float()
291
+ else:
292
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
293
+ weighting = 2 / (math.pi * bot)
294
+ else:
295
+ weighting = None # torch.ones_like(sigmas)
296
+ return weighting
297
+
298
+
299
+ class FineTuningTrainer:
300
+ def __init__(self):
301
+ pass
302
+
303
+ def process_sample_prompts(
304
+ self,
305
+ args: argparse.Namespace,
306
+ accelerator: Accelerator,
307
+ sample_prompts: str,
308
+ text_encoder1: str,
309
+ text_encoder2: str,
310
+ fp8_llm: bool,
311
+ ):
312
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
313
+ prompts = load_prompts(sample_prompts)
314
+
315
+ def encode_for_text_encoder(text_encoder, is_llm=True):
316
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
317
+ with accelerator.autocast(), torch.no_grad():
318
+ for prompt_dict in prompts:
319
+ for p in [prompt_dict.get("prompt", "")]:
320
+ if p not in sample_prompts_te_outputs:
321
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
322
+
323
+ data_type = "video"
324
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
325
+
326
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
327
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
328
+
329
+ return sample_prompts_te_outputs
330
+
331
+ # Load Text Encoder 1 and encode
332
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
333
+ logger.info(f"loading text encoder 1: {text_encoder1}")
334
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
335
+
336
+ logger.info("encoding with Text Encoder 1")
337
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
338
+ del text_encoder_1
339
+
340
+ # Load Text Encoder 2 and encode
341
+ logger.info(f"loading text encoder 2: {text_encoder2}")
342
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
343
+
344
+ logger.info("encoding with Text Encoder 2")
345
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
346
+ del text_encoder_2
347
+
348
+ # prepare sample parameters
349
+ sample_parameters = []
350
+ for prompt_dict in prompts:
351
+ prompt_dict_copy = prompt_dict.copy()
352
+ p = prompt_dict.get("prompt", "")
353
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
354
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
355
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
356
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
357
+ sample_parameters.append(prompt_dict_copy)
358
+
359
+ clean_memory_on_device(accelerator.device)
360
+
361
+ return sample_parameters
362
+
363
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
364
+ # adamw, adamw8bit, adafactor
365
+
366
+ optimizer_type = args.optimizer_type.lower()
367
+
368
+ # split optimizer_type and optimizer_args
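+ # e.g. optimizer_args ["weight_decay=0.01", "betas=(0.9,0.99)"] is parsed into
+ # {"weight_decay": 0.01, "betas": (0.9, 0.99)} via ast.literal_eval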
369
+ optimizer_kwargs = {}
370
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
371
+ for arg in args.optimizer_args:
372
+ key, value = arg.split("=")
373
+ value = ast.literal_eval(value)
374
+ optimizer_kwargs[key] = value
375
+
376
+ lr = args.learning_rate
377
+ optimizer = None
378
+ optimizer_class = None
379
+
380
+ if optimizer_type.endswith("8bit".lower()):
381
+ try:
382
+ import bitsandbytes as bnb
383
+ except ImportError:
384
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
385
+
386
+ if optimizer_type == "AdamW8bit".lower():
387
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
388
+ optimizer_class = bnb.optim.AdamW8bit
389
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
390
+
391
+ elif optimizer_type == "Adafactor".lower():
392
+ # Adafactor: check relative_step and warmup_init
393
+ if "relative_step" not in optimizer_kwargs:
394
+ optimizer_kwargs["relative_step"] = True # default
395
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
396
+ logger.info(
397
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
398
+ )
399
+ optimizer_kwargs["relative_step"] = True
400
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
401
+
402
+ if optimizer_kwargs["relative_step"]:
403
+ logger.info(f"relative_step is true / relative_stepがtrueです")
404
+ if lr != 0.0:
405
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
406
+ args.learning_rate = None
407
+
408
+ if args.lr_scheduler != "adafactor":
409
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
410
+ args.lr_scheduler = f"adafactor:{lr}"  # a bit awkward, but it works
411
+
412
+ lr = None
413
+ else:
414
+ if args.max_grad_norm != 0.0:
415
+ logger.warning(
416
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
417
+ )
418
+ if args.lr_scheduler != "constant_with_warmup":
419
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
420
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
421
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
422
+
423
+ optimizer_class = transformers.optimization.Adafactor
424
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
425
+
426
+ elif optimizer_type == "AdamW".lower():
427
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
428
+ optimizer_class = torch.optim.AdamW
429
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
430
+
431
+ if optimizer is None:
432
+ # use an arbitrary optimizer
433
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
434
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
435
+
436
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
437
+ optimizer_module = torch.optim
438
+ else: # from other library
439
+ values = case_sensitive_optimizer_type.split(".")
440
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
441
+ case_sensitive_optimizer_type = values[-1]
442
+
443
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
444
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
445
+
446
+ # for logging
447
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
448
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
449
+
450
+ # get train and eval functions
451
+ if hasattr(optimizer, "train") and callable(optimizer.train):
452
+ train_fn = optimizer.train
453
+ eval_fn = optimizer.eval
454
+ else:
455
+ train_fn = lambda: None
456
+ eval_fn = lambda: None
457
+
458
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
459
+
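# Hedged sketch of the two Adafactor configurations handled above; the kwargs are standard
# transformers.optimization.Adafactor arguments and the lr value 2e-6 is illustrative.
import torch
from transformers.optimization import Adafactor, AdafactorSchedule

params = [torch.nn.Parameter(torch.zeros(4))]

# relative_step=True: Adafactor manages its own lr, so it is paired with AdafactorSchedule
# (this is what the "adafactor:{lr}" scheduler string above ends up doing).
opt_rel = Adafactor(params, lr=None, relative_step=True, warmup_init=True, scale_parameter=True)
sched_rel = AdafactorSchedule(opt_rel, initial_lr=2e-6)

# relative_step=False: external lr, the mode the warnings above recommend combining with
# constant_with_warmup and clip_threshold=1.0.
opt_ext = Adafactor(params, lr=2e-6, relative_step=False, warmup_init=False,
                    scale_parameter=False, clip_threshold=1.0)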
460
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
461
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
462
+
463
+ def get_dummy_scheduler(self, optimizer: torch.optim.Optimizer) -> Any:
464
+ # dummy scheduler for schedulefree optimizers. supports only an empty step(), get_last_lr() and the wrapped optimizer.
465
+ # this scheduler is used for logging only.
466
+ # it is not wrapped by accelerator because DummyScheduler is not a subclass of torch.optim.lr_scheduler._LRScheduler
467
+ class DummyScheduler:
468
+ def __init__(self, optimizer: torch.optim.Optimizer):
469
+ self.optimizer = optimizer
470
+
471
+ def step(self):
472
+ pass
473
+
474
+ def get_last_lr(self):
475
+ return [group["lr"] for group in self.optimizer.param_groups]
476
+
477
+ return DummyScheduler(optimizer)
478
+
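# Hedged sketch of how the train/eval hooks returned by get_optimizer are meant to be used.
# Schedule-free optimizers expose train()/eval() and carry their own schedule, which is why a
# DummyScheduler is enough for them; torch.optim.AdamW below is just a stand-in.
import torch

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

if hasattr(optimizer, "train") and callable(optimizer.train):
    train_fn, eval_fn = optimizer.train, optimizer.eval
else:
    train_fn = eval_fn = lambda: None

eval_fn()   # switch the optimizer to eval mode before sampling images or saving a checkpoint
# ... sample / save ...
train_fn()  # and back to train mode before the next optimization step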
479
+ def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
480
+ """
481
+ Unified API to get any scheduler from its name.
482
+ """
483
+ # if schedulefree optimizer, return dummy scheduler
484
+ if self.is_schedulefree_optimizer(optimizer, args):
485
+ return self.get_dummy_scheduler(optimizer)
486
+
487
+ name = args.lr_scheduler
488
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
489
+ num_warmup_steps: Optional[int] = (
490
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
491
+ )
492
+ num_decay_steps: Optional[int] = (
493
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
494
+ )
495
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
496
+ num_cycles = args.lr_scheduler_num_cycles
497
+ power = args.lr_scheduler_power
498
+ timescale = args.lr_scheduler_timescale
499
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
500
+
501
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
502
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
503
+ for arg in args.lr_scheduler_args:
504
+ key, value = arg.split("=")
505
+ value = ast.literal_eval(value)
506
+ lr_scheduler_kwargs[key] = value
507
+
508
+ def wrap_check_needless_num_warmup_steps(return_vals):
509
+ if num_warmup_steps is not None and num_warmup_steps != 0:
510
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
511
+ return return_vals
512
+
513
+ # using any lr_scheduler from other library
514
+ if args.lr_scheduler_type:
515
+ lr_scheduler_type = args.lr_scheduler_type
516
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
517
+ if "." not in lr_scheduler_type: # default to use torch.optim
518
+ lr_scheduler_module = torch.optim.lr_scheduler
519
+ else:
520
+ values = lr_scheduler_type.split(".")
521
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
522
+ lr_scheduler_type = values[-1]
523
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
524
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
525
+ return lr_scheduler
526
+
527
+ if name.startswith("adafactor"):
528
+ assert (
529
+ type(optimizer) == transformers.optimization.Adafactor
530
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
531
+ initial_lr = float(name.split(":")[1])
532
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
533
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
534
+
535
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
536
+ name = DiffusersSchedulerType(name)
537
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
538
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
539
+
540
+ name = SchedulerType(name)
541
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
542
+
543
+ if name == SchedulerType.CONSTANT:
544
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
545
+
546
+ # All other schedulers require `num_warmup_steps`
547
+ if num_warmup_steps is None:
548
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
549
+
550
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
551
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
552
+
553
+ if name == SchedulerType.INVERSE_SQRT:
554
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
555
+
556
+ # All other schedulers require `num_training_steps`
557
+ if num_training_steps is None:
558
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
559
+
560
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
561
+ return schedule_func(
562
+ optimizer,
563
+ num_warmup_steps=num_warmup_steps,
564
+ num_training_steps=num_training_steps,
565
+ num_cycles=num_cycles,
566
+ **lr_scheduler_kwargs,
567
+ )
568
+
569
+ if name == SchedulerType.POLYNOMIAL:
570
+ return schedule_func(
571
+ optimizer,
572
+ num_warmup_steps=num_warmup_steps,
573
+ num_training_steps=num_training_steps,
574
+ power=power,
575
+ **lr_scheduler_kwargs,
576
+ )
577
+
578
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
579
+ return schedule_func(
580
+ optimizer,
581
+ num_warmup_steps=num_warmup_steps,
582
+ num_training_steps=num_training_steps,
583
+ num_cycles=num_cycles / 2,
584
+ min_lr_rate=min_lr_ratio,
585
+ **lr_scheduler_kwargs,
586
+ )
587
+
588
+ # these schedulers do not require `num_decay_steps`
589
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
590
+ return schedule_func(
591
+ optimizer,
592
+ num_warmup_steps=num_warmup_steps,
593
+ num_training_steps=num_training_steps,
594
+ **lr_scheduler_kwargs,
595
+ )
596
+
597
+ # All other schedulers require `num_decay_steps`
598
+ if num_decay_steps is None:
599
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
600
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
601
+ return schedule_func(
602
+ optimizer,
603
+ num_warmup_steps=num_warmup_steps,
604
+ num_stable_steps=num_stable_steps,
605
+ num_decay_steps=num_decay_steps,
606
+ num_cycles=num_cycles / 2,
607
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
608
+ **lr_scheduler_kwargs,
609
+ )
610
+
611
+ return schedule_func(
612
+ optimizer,
613
+ num_warmup_steps=num_warmup_steps,
614
+ num_training_steps=num_training_steps,
615
+ num_decay_steps=num_decay_steps,
616
+ **lr_scheduler_kwargs,
617
+ )
618
+
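# Hedged arithmetic sketch of how the step counts above are derived when --lr_warmup_steps and
# --lr_decay_steps are given as a ratio vs. an absolute count (values are illustrative).
max_train_steps = 1600
num_processes = 2
lr_warmup_steps = 0.05   # float -> ratio of training steps
lr_decay_steps = 200     # int   -> absolute number of steps

num_training_steps = max_train_steps * num_processes                        # 3200
num_warmup_steps = int(lr_warmup_steps * num_training_steps)                # 160
num_decay_steps = lr_decay_steps                                            # 200
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps  # 2840 (used by warmup_stable_decay)
print(num_warmup_steps, num_stable_steps, num_decay_steps)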
619
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
620
+ if not args.resume:
621
+ return False
622
+
623
+ if not args.resume_from_huggingface:
624
+ logger.info(f"resume training from local state: {args.resume}")
625
+ accelerator.load_state(args.resume)
626
+ return True
627
+
628
+ logger.info(f"resume training from huggingface state: {args.resume}")
629
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
630
+ path_in_repo = "/".join(args.resume.split("/")[2:])
631
+ revision = None
632
+ repo_type = None
633
+ if ":" in path_in_repo:
634
+ divided = path_in_repo.split(":")
635
+ if len(divided) == 2:
636
+ path_in_repo, revision = divided
637
+ repo_type = "model"
638
+ else:
639
+ path_in_repo, revision, repo_type = divided
640
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
641
+
642
+ list_files = huggingface_utils.list_dir(
643
+ repo_id=repo_id,
644
+ subfolder=path_in_repo,
645
+ revision=revision,
646
+ token=args.huggingface_token,
647
+ repo_type=repo_type,
648
+ )
649
+
650
+ async def download(filename) -> str:
651
+ def task():
652
+ return huggingface_hub.hf_hub_download(
653
+ repo_id=repo_id,
654
+ filename=filename,
655
+ revision=revision,
656
+ repo_type=repo_type,
657
+ token=args.huggingface_token,
658
+ )
659
+
660
+ return await asyncio.get_event_loop().run_in_executor(None, task)
661
+
662
+ loop = asyncio.get_event_loop()
663
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
664
+ if len(results) == 0:
665
+ raise ValueError(
666
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
667
+ )
668
+ dirname = os.path.dirname(results[0])
669
+ accelerator.load_state(dirname)
670
+
671
+ return True
672
+
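# Hedged sketch of the --resume string format parsed above when --resume_from_huggingface is set:
# {repo_id}/{path_in_repo}:{revision}:{repo_type}. The values below are illustrative.
resume = "user/my-repo/states/state-000100:main:model"
parts = resume.split("/")
repo_id = "/".join(parts[:2])        # "user/my-repo"
path_in_repo = "/".join(parts[2:])   # "states/state-000100:main:model"
revision = repo_type = None
if ":" in path_in_repo:
    divided = path_in_repo.split(":")
    if len(divided) == 2:
        path_in_repo, revision = divided
        repo_type = "model"
    else:
        path_in_repo, revision, repo_type = divided
print(repo_id, path_in_repo, revision, repo_type)  # user/my-repo states/state-000100 main model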
673
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
674
+ pass
675
+
676
+ def get_noisy_model_input_and_timesteps(
677
+ self,
678
+ args: argparse.Namespace,
679
+ noise: torch.Tensor,
680
+ latents: torch.Tensor,
681
+ noise_scheduler: FlowMatchDiscreteScheduler,
682
+ device: torch.device,
683
+ dtype: torch.dtype,
684
+ ):
685
+ batch_size = noise.shape[0]
686
+
687
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
688
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
689
+ # Simple random t-based noise sampling
690
+ if args.timestep_sampling == "sigmoid":
691
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
692
+ else:
693
+ t = torch.rand((batch_size,), device=device)
694
+
695
+ elif args.timestep_sampling == "shift":
696
+ shift = args.discrete_flow_shift
697
+ logits_norm = torch.randn(batch_size, device=device)
698
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
699
+ t = logits_norm.sigmoid()
700
+ t = (t * shift) / (1 + (shift - 1) * t)
701
+
702
+ t_min = args.min_timestep if args.min_timestep is not None else 0
703
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
704
+ t_min /= 1000.0
705
+ t_max /= 1000.0
706
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
707
+
708
+ timesteps = t * 1000.0
709
+ t = t.view(-1, 1, 1, 1, 1)
710
+ noisy_model_input = (1 - t) * latents + t * noise
711
+
712
+ timesteps += 1 # 1 to 1000
713
+ else:
714
+ # Sample a random timestep for each image
715
+ # for weighting schemes where we sample timesteps non-uniformly
716
+ u = compute_density_for_timestep_sampling(
717
+ weighting_scheme=args.weighting_scheme,
718
+ batch_size=batch_size,
719
+ logit_mean=args.logit_mean,
720
+ logit_std=args.logit_std,
721
+ mode_scale=args.mode_scale,
722
+ )
723
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
724
+ t_min = args.min_timestep if args.min_timestep is not None else 0
725
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
726
+ indices = (u * (t_max - t_min) + t_min).long()
727
+
728
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
729
+
730
+ # Add noise according to flow matching.
731
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
732
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
733
+
734
+ return noisy_model_input, timesteps
735
+
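# Hedged sketch of the "shift" timestep sampling branch above and the flow-matching interpolation
# x_t = (1 - t) * x_0 + t * noise (shapes and values are illustrative).
import torch

batch_size, shift, sigmoid_scale = 4, 3.0, 1.0
t = torch.sigmoid(sigmoid_scale * torch.randn(batch_size))  # base sample in (0, 1)
t = (t * shift) / (1 + (shift - 1) * t)                     # shift pushes t toward 1 (noisier timesteps)
timesteps = t * 1000.0 + 1                                  # scheduler timesteps, roughly 1..1000

latents = torch.randn(batch_size, 16, 1, 8, 8)              # B, C, F, H, W
noise = torch.randn_like(latents)
t_b = t.view(-1, 1, 1, 1, 1)
noisy_model_input = (1 - t_b) * latents + t_b * noise       # linear path between data and noise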
736
+ def train(self, args):
737
+ if args.seed is None:
738
+ args.seed = random.randint(0, 2**32)
739
+ set_seed(args.seed)
740
+
741
+ # Load dataset config
742
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
743
+ logger.info(f"Load dataset config from {args.dataset_config}")
744
+ user_config = config_utils.load_user_config(args.dataset_config)
745
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
746
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
747
+
748
+ current_epoch = Value("i", 0)
749
+ current_step = Value("i", 0)
750
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
751
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
752
+
753
+ # prepare accelerator
754
+ logger.info("preparing accelerator")
755
+ accelerator = prepare_accelerator(args)
756
+ is_main_process = accelerator.is_main_process
757
+
758
+ # prepare dtype
759
+ weight_dtype = torch.float32
760
+ if args.mixed_precision == "fp16":
761
+ weight_dtype = torch.float16
762
+ elif args.mixed_precision == "bf16":
763
+ weight_dtype = torch.bfloat16
764
+
765
+ # HunyuanVideo specific
766
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
767
+
768
+ # get embedding for sampling images
769
+ sample_parameters = vae = None
770
+ if args.sample_prompts:
771
+ sample_parameters = self.process_sample_prompts(
772
+ args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
773
+ )
774
+
775
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
776
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
777
+ vae.requires_grad_(False)
778
+ vae.eval()
779
+
780
+ if args.vae_chunk_size is not None:
781
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
782
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
783
+ if args.vae_spatial_tile_sample_min_size is not None:
784
+ vae.enable_spatial_tiling(True)
785
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
786
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
787
+ elif args.vae_tiling:
788
+ vae.enable_spatial_tiling(True)
789
+
790
+ # load DiT model
791
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
792
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
793
+
794
+ logger.info(f"Loading DiT model from {args.dit}")
795
+ if args.sdpa:
796
+ attn_mode = "torch"
797
+ elif args.flash_attn:
798
+ attn_mode = "flash"
799
+ elif args.sage_attn:
800
+ attn_mode = "sageattn"
801
+ elif args.xformers:
802
+ attn_mode = "xformers"
803
+ else:
804
+ raise ValueError(
805
+ f"either --sdpa, --flash_attn, --sage_attn or --xformers must be specified / --sdpa, --flash_attn, --sage_attn, --xformersのいずれかを指定してください"
806
+ )
807
+ transformer = load_transformer(
808
+ args.dit, attn_mode, args.split_attn, loading_device, None, in_channels=args.dit_in_channels
809
+ ) # load as is
810
+
811
+ if blocks_to_swap > 0:
812
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
813
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
814
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
815
+ if args.img_in_txt_in_offloading:
816
+ logger.info("Enable offloading img_in and txt_in to CPU")
817
+ transformer.enable_img_in_txt_in_offloading()
818
+
819
+ if args.gradient_checkpointing:
820
+ transformer.enable_gradient_checkpointing()
821
+
822
+ # prepare optimizer, data loader etc.
823
+ accelerator.print("prepare optimizer, data loader etc.")
824
+
825
+ transformer.requires_grad_(False)
826
+ if accelerator.is_main_process:
827
+ accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
828
+ for name, param in transformer.named_parameters():
829
+ for trainable_module_name in args.trainable_modules:
830
+ if trainable_module_name in name:
831
+ param.requires_grad = True
832
+ break
833
+
834
+ total_params = list(transformer.parameters())
835
+ trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
836
+ logger.info(
837
+ f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total parameters: {sum(p.numel() for p in total_params) / 1e6} M"
838
+ )
839
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
840
+ args, trainable_params
841
+ )
842
+
843
+ # prepare dataloader
844
+
845
+ # num workers for data loader: if 0, persistent_workers is not available
846
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
847
+
848
+ train_dataloader = torch.utils.data.DataLoader(
849
+ train_dataset_group,
850
+ batch_size=1,
851
+ shuffle=True,
852
+ collate_fn=collator,
853
+ num_workers=n_workers,
854
+ persistent_workers=args.persistent_data_loader_workers,
855
+ )
856
+
857
+ # calculate max_train_steps
858
+ if args.max_train_epochs is not None:
859
+ args.max_train_steps = args.max_train_epochs * math.ceil(
860
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
861
+ )
862
+ accelerator.print(
863
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
864
+ )
865
+
866
+ # send max_train_steps to train_dataset_group
867
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
868
+
869
+ # prepare lr_scheduler
870
+ lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
871
+
872
+ # prepare training model. accelerator does some magic here
873
+
874
+ # experimental feature: train the model with gradients in fp16/bf16
875
+ dit_dtype = torch.float32
876
+ if args.full_fp16:
877
+ assert (
878
+ args.mixed_precision == "fp16"
879
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
880
+ accelerator.print("enable full fp16 training.")
881
+ dit_weight_dtype = torch.float16
882
+ elif args.full_bf16:
883
+ assert (
884
+ args.mixed_precision == "bf16"
885
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
886
+ accelerator.print("enable full bf16 training.")
887
+ dit_weight_dtype = torch.bfloat16
888
+ else:
889
+ dit_weight_dtype = torch.float32
890
+
891
+ # TODO add fused optimizer and stochastic rounding
892
+
893
+ # cast model to dit_weight_dtype
894
+ # if dit_dtype != dit_weight_dtype:
895
+ logger.info(f"casting model to {dit_weight_dtype}")
896
+ transformer.to(dit_weight_dtype)
897
+
898
+ if blocks_to_swap > 0:
899
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
900
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
901
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
902
+ else:
903
+ transformer = accelerator.prepare(transformer)
904
+
905
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
906
+
907
+ transformer.train()
908
+
909
+ if args.full_fp16:
910
+ # patch accelerator for fp16 training
911
+ # def patch_accelerator_for_fp16_training(accelerator):
912
+ org_unscale_grads = accelerator.scaler._unscale_grads_
913
+
914
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
915
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
916
+
917
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
918
+
919
+ # resume from local or huggingface. accelerator.step is set
920
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
921
+
922
+ # calculate the number of epochs
923
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
924
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
925
+
926
+ # start training
927
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
928
+
929
+ accelerator.print("running training / 学習開始")
930
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
931
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
932
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
933
+ accelerator.print(
934
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
935
+ )
936
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
937
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
938
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
939
+
940
+ if accelerator.is_main_process:
941
+ init_kwargs = {}
942
+ if args.wandb_run_name:
943
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
944
+ if args.log_tracker_config is not None:
945
+ init_kwargs = toml.load(args.log_tracker_config)
946
+ accelerator.init_trackers(
947
+ "hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
948
+ config=train_utils.get_sanitized_config_or_none(args),
949
+ init_kwargs=init_kwargs,
950
+ )
951
+
952
+ # TODO skip until initial step
953
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
954
+
955
+ epoch_to_start = 0
956
+ global_step = 0
957
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
958
+
959
+ loss_recorder = train_utils.LossRecorder()
960
+ del train_dataset_group
961
+
962
+ # function for saving/removing
963
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
964
+ os.makedirs(args.output_dir, exist_ok=True)
965
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
966
+
967
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
968
+
969
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
970
+ if args.min_timestep is not None or args.max_timestep is not None:
971
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
972
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
973
+ md_timesteps = (min_time_step, max_time_step)
974
+ else:
975
+ md_timesteps = None
976
+
977
+ sai_metadata = sai_model_spec.build_metadata(
978
+ None,
979
+ ARCHITECTURE_HUNYUAN_VIDEO,
980
+ time.time(),
981
+ title,
982
+ None,
983
+ args.metadata_author,
984
+ args.metadata_description,
985
+ args.metadata_license,
986
+ args.metadata_tags,
987
+ timesteps=md_timesteps,
988
+ is_lora=False,
989
+ )
990
+
991
+ save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
992
+ if args.huggingface_repo_id is not None:
993
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
994
+
995
+ def remove_model(old_ckpt_name):
996
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
997
+ if os.path.exists(old_ckpt_file):
998
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
999
+ os.remove(old_ckpt_file)
1000
+
1001
+ # For --sample_at_first
1002
+ optimizer_eval_fn()
1003
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
1004
+ optimizer_train_fn()
1005
+ if len(accelerator.trackers) > 0:
1006
+ # log empty object to commit the sample images to wandb
1007
+ accelerator.log({}, step=0)
1008
+
1009
+ # training loop
1010
+
1011
+ # log device and dtype for each model
1012
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
1013
+
1014
+ clean_memory_on_device(accelerator.device)
1015
+
1016
+ pos_embed_cache = {}
1017
+
1018
+ for epoch in range(epoch_to_start, num_train_epochs):
1019
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
1020
+ current_epoch.value = epoch + 1
1021
+
1022
+ for step, batch in enumerate(train_dataloader):
1023
+ latents, llm_embeds, llm_mask, clip_embeds = batch
1024
+ bsz = latents.shape[0]
1025
+ current_step.value = global_step
1026
+
1027
+ with accelerator.accumulate(transformer):
1028
+ latents = latents * vae_module.SCALING_FACTOR
1029
+
1030
+ # Sample noise that we'll add to the latents
1031
+ noise = torch.randn_like(latents)
1032
+
1033
+ # calculate model input and timesteps
1034
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
1035
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
1036
+ )
1037
+
1038
+ weighting = compute_loss_weighting_for_sd3(
1039
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
1040
+ )
1041
+
1042
+ # ensure guidance_scale in args is float
1043
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
1044
+
1045
+ # ensure the hidden state will require grad
1046
+ if args.gradient_checkpointing:
1047
+ noisy_model_input.requires_grad_(True)
1048
+ guidance_vec.requires_grad_(True)
1049
+
1050
+ pos_emb_shape = latents.shape[1:]
1051
+ if pos_emb_shape not in pos_embed_cache:
1052
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(
1053
+ accelerator.unwrap_model(transformer), latents.shape[2:]
1054
+ )
1055
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
1056
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
1057
+ pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
1058
+ else:
1059
+ freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
1060
+
1061
+ # call DiT
1062
+ latents = latents.to(device=accelerator.device, dtype=dit_dtype)
1063
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
1064
+ # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
1065
+ # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
1066
+ # llm_mask = llm_mask.to(device=accelerator.device)
1067
+ # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
1068
+ with accelerator.autocast():
1069
+ model_pred = transformer(
1070
+ noisy_model_input,
1071
+ timesteps,
1072
+ text_states=llm_embeds,
1073
+ text_mask=llm_mask,
1074
+ text_states_2=clip_embeds,
1075
+ freqs_cos=freqs_cos,
1076
+ freqs_sin=freqs_sin,
1077
+ guidance=guidance_vec,
1078
+ return_dict=False,
1079
+ )
1080
+
1081
+ # flow matching loss
1082
+ target = noise - latents
1083
+
1084
+ loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")
1085
+
1086
+ if weighting is not None:
1087
+ loss = loss * weighting
1088
+ # loss = loss.mean([1, 2, 3])
1089
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
1090
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
1091
+
1092
+ loss = loss.mean() # mean over all elements, so no need to divide by batch_size
1093
+
1094
+ accelerator.backward(loss)
1095
+ if accelerator.sync_gradients:
1096
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
1097
+ state = accelerate.PartialState()
1098
+ if state.distributed_type != accelerate.DistributedType.NO:
1099
+ for param in transformer.parameters():
1100
+ if param.grad is not None:
1101
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
1102
+
1103
+ if args.max_grad_norm != 0.0:
1104
+ params_to_clip = accelerator.unwrap_model(transformer).parameters()
1105
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1106
+
1107
+ optimizer.step()
1108
+ lr_scheduler.step()
1109
+ optimizer.zero_grad(set_to_none=True)
1110
+
1111
+ # Checks if the accelerator has performed an optimization step behind the scenes
1112
+ if accelerator.sync_gradients:
1113
+ progress_bar.update(1)
1114
+ global_step += 1
1115
+
1116
+ optimizer_eval_fn()
1117
+ self.sample_images(
1118
+ accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
1119
+ )
1120
+
1121
+ # save the model every specified number of steps
1122
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1123
+ accelerator.wait_for_everyone()
1124
+ if accelerator.is_main_process:
1125
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
1126
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)
1127
+
1128
+ if args.save_state:
1129
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
1130
+
1131
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
1132
+ if remove_step_no is not None:
1133
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
1134
+ remove_model(remove_ckpt_name)
1135
+ optimizer_train_fn()
1136
+
1137
+ current_loss = loss.detach().item()
1138
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1139
+ avr_loss: float = loss_recorder.moving_average
1140
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1141
+ progress_bar.set_postfix(**logs)
1142
+
1143
+ if len(accelerator.trackers) > 0:
1144
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
1145
+ accelerator.log(logs, step=global_step)
1146
+
1147
+ if global_step >= args.max_train_steps:
1148
+ break
1149
+
1150
+ if len(accelerator.trackers) > 0:
1151
+ logs = {"loss/epoch": loss_recorder.moving_average}
1152
+ accelerator.log(logs, step=epoch + 1)
1153
+
1154
+ accelerator.wait_for_everyone()
1155
+
1156
+ # save the model every specified number of epochs
1157
+ optimizer_eval_fn()
1158
+ if args.save_every_n_epochs is not None:
1159
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1160
+ if is_main_process and saving:
1161
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
1162
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)
1163
+
1164
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
1165
+ if remove_epoch_no is not None:
1166
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
1167
+ remove_model(remove_ckpt_name)
1168
+
1169
+ if args.save_state:
1170
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1171
+
1172
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
1173
+ optimizer_train_fn()
1174
+
1175
+ # end of epoch
1176
+
1177
+ if is_main_process:
1178
+ transformer = accelerator.unwrap_model(transformer)
1179
+
1180
+ accelerator.end_training()
1181
+ optimizer_eval_fn()
1182
+
1183
+ if args.save_state or args.save_state_on_train_end:
1184
+ train_utils.save_state_on_train_end(args, accelerator)
1185
+
1186
+ if is_main_process:
1187
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
1188
+ save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)
1189
+
1190
+ logger.info("model saved.")
1191
+
1192
+
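# Hedged sketch of the flow-matching objective computed in the training loop above: the DiT is
# trained to predict the constant velocity (noise - latents) along the linear path; the tensors
# below are illustrative stand-ins for the real model output.
import torch

latents = torch.randn(2, 16, 1, 8, 8)
noise = torch.randn_like(latents)
model_pred = torch.randn_like(latents)   # stand-in for transformer(noisy_model_input, ...)

target = noise - latents
loss = torch.nn.functional.mse_loss(model_pred, target, reduction="none")
loss = loss.mean()   # elementwise MSE averaged over everything, so no per-batch division is needed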
1193
+ def setup_parser() -> argparse.ArgumentParser:
1194
+ def int_or_float(value):
1195
+ if value.endswith("%"):
1196
+ try:
1197
+ return float(value[:-1]) / 100.0
1198
+ except ValueError:
1199
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
1200
+ try:
1201
+ float_value = float(value)
1202
+ if float_value >= 1 and float_value.is_integer():
1203
+ return int(value)
1204
+ return float(value)
1205
+ except ValueError:
1206
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
1207
+
1208
+ parser = argparse.ArgumentParser()
1209
+
1210
+ # general settings
1211
+ parser.add_argument(
1212
+ "--config_file",
1213
+ type=str,
1214
+ default=None,
1215
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
1216
+ )
1217
+ parser.add_argument(
1218
+ "--dataset_config",
1219
+ type=pathlib.Path,
1220
+ default=None,
1221
+ required=True,
1222
+ help="config file for dataset / データセットの設定ファイル",
1223
+ )
1224
+
1225
+ # training settings
1226
+ parser.add_argument(
1227
+ "--sdpa",
1228
+ action="store_true",
1229
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
1230
+ )
1231
+ parser.add_argument(
1232
+ "--flash_attn",
1233
+ action="store_true",
1234
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
1235
+ )
1236
+ parser.add_argument(
1237
+ "--sage_attn",
1238
+ action="store_true",
1239
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
1240
+ )
1241
+ parser.add_argument(
1242
+ "--xformers",
1243
+ action="store_true",
1244
+ help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
1245
+ )
1246
+ parser.add_argument(
1247
+ "--split_attn",
1248
+ action="store_true",
1249
+ help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
1250
+ " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
1251
+ )
1252
+
1253
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1254
+ parser.add_argument(
1255
+ "--max_train_epochs",
1256
+ type=int,
1257
+ default=None,
1258
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
1259
+ )
1260
+ parser.add_argument(
1261
+ "--max_data_loader_n_workers",
1262
+ type=int,
1263
+ default=8,
1264
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
1265
+ )
1266
+ parser.add_argument(
1267
+ "--persistent_data_loader_workers",
1268
+ action="store_true",
1269
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
1270
+ )
1271
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1272
+ parser.add_argument(
1273
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
1274
+ )
1275
+ parser.add_argument(
1276
+ "--gradient_accumulation_steps",
1277
+ type=int,
1278
+ default=1,
1279
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
1280
+ )
1281
+ parser.add_argument(
1282
+ "--mixed_precision",
1283
+ type=str,
1284
+ default="no",
1285
+ choices=["no", "fp16", "bf16"],
1286
+ help="use mixed precision / 混合精度を使う場合、その精度",
1287
+ )
1288
+ parser.add_argument("--trainable_modules", nargs="+", default=".", help="Enter a list of trainable modules")
1289
+
1290
+ parser.add_argument(
1291
+ "--logging_dir",
1292
+ type=str,
1293
+ default=None,
1294
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
1295
+ )
1296
+ parser.add_argument(
1297
+ "--log_with",
1298
+ type=str,
1299
+ default=None,
1300
+ choices=["tensorboard", "wandb", "all"],
1301
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
1302
+ )
1303
+ parser.add_argument(
1304
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
1305
+ )
1306
+ parser.add_argument(
1307
+ "--log_tracker_name",
1308
+ type=str,
1309
+ default=None,
1310
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
1311
+ )
1312
+ parser.add_argument(
1313
+ "--wandb_run_name",
1314
+ type=str,
1315
+ default=None,
1316
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
1317
+ )
1318
+ parser.add_argument(
1319
+ "--log_tracker_config",
1320
+ type=str,
1321
+ default=None,
1322
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
1323
+ )
1324
+ parser.add_argument(
1325
+ "--wandb_api_key",
1326
+ type=str,
1327
+ default=None,
1328
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
1329
+ )
1330
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
1331
+
1332
+ parser.add_argument(
1333
+ "--ddp_timeout",
1334
+ type=int,
1335
+ default=None,
1336
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
1337
+ )
1338
+ parser.add_argument(
1339
+ "--ddp_gradient_as_bucket_view",
1340
+ action="store_true",
1341
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
1342
+ )
1343
+ parser.add_argument(
1344
+ "--ddp_static_graph",
1345
+ action="store_true",
1346
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
1347
+ )
1348
+
1349
+ parser.add_argument(
1350
+ "--sample_every_n_steps",
1351
+ type=int,
1352
+ default=None,
1353
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
1354
+ )
1355
+ parser.add_argument(
1356
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
1357
+ )
1358
+ parser.add_argument(
1359
+ "--sample_every_n_epochs",
1360
+ type=int,
1361
+ default=None,
1362
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
1363
+ )
1364
+ parser.add_argument(
1365
+ "--sample_prompts",
1366
+ type=str,
1367
+ default=None,
1368
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
1369
+ )
1370
+
1371
+ # optimizer and lr scheduler settings
1372
+ parser.add_argument(
1373
+ "--optimizer_type",
1374
+ type=str,
1375
+ default="",
1376
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
1377
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc.",
1378
+ )
1379
+ parser.add_argument(
1380
+ "--optimizer_args",
1381
+ type=str,
1382
+ default=None,
1383
+ nargs="*",
1384
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
1385
+ )
1386
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1387
+ parser.add_argument(
1388
+ "--max_grad_norm",
1389
+ default=1.0,
1390
+ type=float,
1391
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
1392
+ )
1393
+
1394
+ parser.add_argument(
1395
+ "--lr_scheduler",
1396
+ type=str,
1397
+ default="constant",
1398
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
1399
+ )
1400
+ parser.add_argument(
1401
+ "--lr_warmup_steps",
1402
+ type=int_or_float,
1403
+ default=0,
1404
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
1405
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1406
+ )
1407
+ parser.add_argument(
1408
+ "--lr_decay_steps",
1409
+ type=int_or_float,
1410
+ default=0,
1411
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
1412
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1413
+ )
1414
+ parser.add_argument(
1415
+ "--lr_scheduler_num_cycles",
1416
+ type=int,
1417
+ default=1,
1418
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
1419
+ )
1420
+ parser.add_argument(
1421
+ "--lr_scheduler_power",
1422
+ type=float,
1423
+ default=1,
1424
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
1425
+ )
1426
+ parser.add_argument(
1427
+ "--lr_scheduler_timescale",
1428
+ type=int,
1429
+ default=None,
1430
+ help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`"
1431
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
1432
+ )
1433
+ parser.add_argument(
1434
+ "--lr_scheduler_min_lr_ratio",
1435
+ type=float,
1436
+ default=None,
1437
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
1438
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
1439
+ )
1440
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
1441
+ parser.add_argument(
1442
+ "--lr_scheduler_args",
1443
+ type=str,
1444
+ default=None,
1445
+ nargs="*",
1446
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
1447
+ )
1448
+
1449
+ # model settings
1450
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
1451
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
1452
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
1453
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
1454
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
1455
+ parser.add_argument(
1456
+ "--vae_tiling",
1457
+ action="store_true",
1458
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
1459
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
1460
+ )
1461
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
1462
+ parser.add_argument(
1463
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
1464
+ )
1465
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
1466
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
1467
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
1468
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
1469
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1470
+ parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
1471
+
1472
+ parser.add_argument(
1473
+ "--blocks_to_swap",
1474
+ type=int,
1475
+ default=None,
1476
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
1477
+ )
1478
+ parser.add_argument(
1479
+ "--img_in_txt_in_offloading",
1480
+ action="store_true",
1481
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
1482
+ )
1483
+
1484
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
1485
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embedded classifier-free guidance scale.")
1486
+ parser.add_argument(
1487
+ "--timestep_sampling",
1488
+ choices=["sigma", "uniform", "sigmoid", "shift"],
1489
+ default="sigma",
1490
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
1491
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
1492
+ )
1493
+ parser.add_argument(
1494
+ "--discrete_flow_shift",
1495
+ type=float,
1496
+ default=1.0,
1497
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
1498
+ )
1499
+ parser.add_argument(
1500
+ "--sigmoid_scale",
1501
+ type=float,
1502
+ default=1.0,
1503
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
1504
+ )
1505
+ parser.add_argument(
1506
+ "--weighting_scheme",
1507
+ type=str,
1508
+ default="none",
1509
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
1510
+ help="weighting scheme for timestep distribution. Default is none"
1511
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
1512
+ )
1513
+ parser.add_argument(
1514
+ "--logit_mean",
1515
+ type=float,
1516
+ default=0.0,
1517
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
1518
+ )
1519
+ parser.add_argument(
1520
+ "--logit_std",
1521
+ type=float,
1522
+ default=1.0,
1523
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
1524
+ )
1525
+ parser.add_argument(
1526
+ "--mode_scale",
1527
+ type=float,
1528
+ default=1.29,
1529
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
1530
+ )
1531
+ parser.add_argument(
1532
+ "--min_timestep",
1533
+ type=int,
1534
+ default=None,
1535
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
1536
+ )
1537
+ parser.add_argument(
1538
+ "--max_timestep",
1539
+ type=int,
1540
+ default=None,
1541
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
1542
+ )
1543
+
1544
+ # save and load settings
1545
+ parser.add_argument(
1546
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
1547
+ )
1548
+ parser.add_argument(
1549
+ "--output_name",
1550
+ type=str,
1551
+ default=None,
1552
+ required=True,
1553
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
1554
+ )
1555
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1556
+
1557
+ parser.add_argument(
1558
+ "--save_every_n_epochs",
1559
+ type=int,
1560
+ default=None,
1561
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
1562
+ )
1563
+ parser.add_argument(
1564
+ "--save_every_n_steps",
1565
+ type=int,
1566
+ default=None,
1567
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
1568
+ )
1569
+ parser.add_argument(
1570
+ "--save_last_n_epochs",
1571
+ type=int,
1572
+ default=None,
1573
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
1574
+ )
1575
+ parser.add_argument(
1576
+ "--save_last_n_epochs_state",
1577
+ type=int,
1578
+ default=None,
1579
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
1580
+ )
1581
+ parser.add_argument(
1582
+ "--save_last_n_steps",
1583
+ type=int,
1584
+ default=None,
1585
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
1586
+ )
1587
+ parser.add_argument(
1588
+ "--save_last_n_steps_state",
1589
+ type=int,
1590
+ default=None,
1591
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
1592
+ )
1593
+ parser.add_argument(
1594
+ "--save_state",
1595
+ action="store_true",
1596
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
1597
+ )
1598
+ parser.add_argument(
1599
+ "--save_state_on_train_end",
1600
+ action="store_true",
1601
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
1602
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
1603
+ )
1604
+
1605
+ # SAI Model spec
1606
+ parser.add_argument(
1607
+ "--metadata_title",
1608
+ type=str,
1609
+ default=None,
1610
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
1611
+ )
1612
+ parser.add_argument(
1613
+ "--metadata_author",
1614
+ type=str,
1615
+ default=None,
1616
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
1617
+ )
1618
+ parser.add_argument(
1619
+ "--metadata_description",
1620
+ type=str,
1621
+ default=None,
1622
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
1623
+ )
1624
+ parser.add_argument(
1625
+ "--metadata_license",
1626
+ type=str,
1627
+ default=None,
1628
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
1629
+ )
1630
+ parser.add_argument(
1631
+ "--metadata_tags",
1632
+ type=str,
1633
+ default=None,
1634
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
1635
+ )
1636
+
1637
+ # huggingface settings
1638
+ parser.add_argument(
1639
+ "--huggingface_repo_id",
1640
+ type=str,
1641
+ default=None,
1642
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
1643
+ )
1644
+ parser.add_argument(
1645
+ "--huggingface_repo_type",
1646
+ type=str,
1647
+ default=None,
1648
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
1649
+ )
1650
+ parser.add_argument(
1651
+ "--huggingface_path_in_repo",
1652
+ type=str,
1653
+ default=None,
1654
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
1655
+ )
1656
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
1657
+ parser.add_argument(
1658
+ "--huggingface_repo_visibility",
1659
+ type=str,
1660
+ default=None,
1661
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
1662
+ )
1663
+ parser.add_argument(
1664
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
1665
+ )
1666
+ parser.add_argument(
1667
+ "--resume_from_huggingface",
1668
+ action="store_true",
1669
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
1670
+ )
1671
+ parser.add_argument(
1672
+ "--async_upload",
1673
+ action="store_true",
1674
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
1675
+ )
1676
+
1677
+ return parser
1678
+
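# Hedged usage sketch for the fine-tuning entry point defined in this file (the script name and
# all paths are illustrative; the flags are the ones registered in setup_parser above):
#
#   accelerate launch hv_train.py \
#       --dit ckpts/hunyuan_video_dit.safetensors \
#       --vae ckpts/hunyuan_video_vae.safetensors \
#       --dataset_config dataset.toml \
#       --output_dir output --output_name hv-ft \
#       --mixed_precision bf16 --sdpa --gradient_checkpointing \
#       --optimizer_type AdamW8bit --learning_rate 2e-6 \
#       --timestep_sampling shift --discrete_flow_shift 7.0 \
#       --max_train_epochs 16 --save_every_n_epochs 1 --seed 42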
1679
+
1680
+ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
1681
+ if not args.config_file:
1682
+ return args
1683
+
1684
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
1685
+
1686
+ if not os.path.exists(config_path):
1687
+ logger.info(f"{config_path} not found.")
1688
+ exit(1)
1689
+
1690
+ logger.info(f"Loading settings from {config_path}...")
1691
+ with open(config_path, "r", encoding="utf-8") as f:
1692
+ config_dict = toml.load(f)
1693
+
1694
+ # combine all sections into one
1695
+ ignore_nesting_dict = {}
1696
+ for section_name, section_dict in config_dict.items():
1697
+ # if value is not dict, save key and value as is
1698
+ if not isinstance(section_dict, dict):
1699
+ ignore_nesting_dict[section_name] = section_dict
1700
+ continue
1701
+
1702
+ # if value is dict, save all key and value into one dict
1703
+ for key, value in section_dict.items():
1704
+ ignore_nesting_dict[key] = value
1705
+
1706
+ config_args = argparse.Namespace(**ignore_nesting_dict)
1707
+ args = parser.parse_args(namespace=config_args)
1708
+ args.config_file = os.path.splitext(args.config_file)[0]
1709
+ logger.info(args.config_file)
1710
+
1711
+ return args
1712
+
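# Hedged sketch of the config flattening done by read_config_from_file above: nested .toml
# sections are merged into one flat namespace, so the section names are purely cosmetic
# (keys and values below are illustrative).
import toml

config_text = """
[model]
dit = "ckpts/dit.safetensors"

[training]
learning_rate = 2e-6
max_train_steps = 1600
"""
config_dict = toml.loads(config_text)

flat = {}
for section_name, section_dict in config_dict.items():
    if not isinstance(section_dict, dict):
        flat[section_name] = section_dict
        continue
    flat.update(section_dict)

print(flat)  # {'dit': 'ckpts/dit.safetensors', 'learning_rate': 2e-06, 'max_train_steps': 1600}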
1713
+
1714
+ if __name__ == "__main__":
1715
+ parser = setup_parser()
1716
+
1717
+ args = parser.parse_args()
1718
+ args = read_config_from_file(args, parser)
1719
+
1720
+ trainer = FineTuningTrainer()
1721
+ trainer.train(args)
hv_train_network.py ADDED
The diff for this file is too large to render. See raw diff
 
wan_cache_latents.py ADDED
@@ -0,0 +1,177 @@
1
+ import argparse
2
+ import os
3
+ import glob
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from dataset import config_utils
11
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
12
+ from PIL import Image
13
+
14
+ import logging
15
+
16
+ from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN
17
+ from utils.model_utils import str_to_dtype
18
+ from wan.configs import wan_i2v_14B
19
+ from wan.modules.vae import WanVAE
20
+ from wan.modules.clip import CLIPModel
21
+ import cache_latents
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
28
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
29
+ if len(contents.shape) == 4:
30
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
31
+
32
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
33
+ contents = contents.to(vae.device, dtype=vae.dtype)
34
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
35
+
36
+ h, w = contents.shape[3], contents.shape[4]
37
+ if h < 8 or w < 8:
38
+ item = batch[0] # other items should have the same size
39
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
40
+
41
+ # print(f"encode batch: {contents.shape}")
42
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
43
+ latent = vae.encode(contents) # list of Tensor[C, F, H, W]
44
+ latent = torch.stack(latent, dim=0) # B, C, F, H, W
45
+ latent = latent.to(vae.dtype) # convert to bfloat16, we are not sure if this is correct
46
+
47
+ if clip is not None:
48
+ # extract first frame of contents
49
+ images = contents[:, :, 0:1, :, :] # B, C, F, H, W, non contiguous view is fine
50
+
51
+ with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
52
+ clip_context = clip.visual(images)
53
+ clip_context = clip_context.to(torch.float16) # convert to fp16
54
+
55
+ # encode image latent for I2V
56
+ B, _, _, lat_h, lat_w = latent.shape
57
+ F = contents.shape[2]
58
+
59
+ # Create mask for the required number of frames
60
+ msk = torch.ones(1, F, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
61
+ msk[:, 1:] = 0
62
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
63
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
64
+ msk = msk.transpose(1, 2) # 1, F, 4, H, W -> 1, 4, F, H, W
65
+ msk = msk.repeat(B, 1, 1, 1, 1) # B, 4, F, H, W
66
+
67
+ # Zero padding for the required number of frames only
68
+ padding_frames = F - 1 # The first frame is the input image
69
+ images_resized = torch.concat([images, torch.zeros(B, 3, padding_frames, h, w, device=vae.device)], dim=2)
70
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
71
+ y = vae.encode(images_resized)
72
+ y = torch.stack(y, dim=0) # B, C, F, H, W
73
+
74
+ y = y[:, :, :F] # may not be needed
75
+ y = y.to(vae.dtype) # convert to bfloat16
76
+ y = torch.concat([msk, y], dim=1) # B, 4 + C, F, H, W
77
+
78
+ else:
79
+ clip_context = None
80
+ y = None
81
+
82
+ # control videos
83
+ if batch[0].control_content is not None:
84
+ control_contents = torch.stack([torch.from_numpy(item.control_content) for item in batch])
85
+ if len(control_contents.shape) == 4:
86
+ control_contents = control_contents.unsqueeze(1)
87
+ control_contents = control_contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
88
+ control_contents = control_contents.to(vae.device, dtype=vae.dtype)
89
+ control_contents = control_contents / 127.5 - 1.0 # normalize to [-1, 1]
90
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
91
+ control_latent = vae.encode(control_contents) # list of Tensor[C, F, H, W]
92
+ control_latent = torch.stack(control_latent, dim=0) # B, C, F, H, W
93
+ control_latent = control_latent.to(vae.dtype) # convert to bfloat16
94
+ else:
95
+ control_latent = None
96
+
97
+ # # debug: decode and save
98
+ # with torch.no_grad():
99
+ # latent_to_decode = latent / vae.config.scaling_factor
100
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
101
+ # images = (images / 2 + 0.5).clamp(0, 1)
102
+ # images = images.cpu().float().numpy()
103
+ # images = (images * 255).astype(np.uint8)
104
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
105
+ # for b in range(images.shape[0]):
106
+ # for f in range(images.shape[1]):
107
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
108
+ # img = Image.fromarray(images[b, f])
109
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
110
+
111
+ for i, item in enumerate(batch):
112
+ l = latent[i]
113
+ cctx = clip_context[i] if clip is not None else None
114
+ y_i = y[i] if clip is not None else None
115
+ control_latent_i = control_latent[i] if control_latent is not None else None
116
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
117
+ save_latent_cache_wan(item, l, cctx, y_i, control_latent_i)
118
+
119
+
120
+ def main(args):
121
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
122
+ device = torch.device(device)
123
+
124
+ # Load dataset config
125
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
126
+ logger.info(f"Load dataset config from {args.dataset_config}")
127
+ user_config = config_utils.load_user_config(args.dataset_config)
128
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
129
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
130
+
131
+ datasets = train_dataset_group.datasets
132
+
133
+ if args.debug_mode is not None:
134
+ cache_latents.show_datasets(
135
+ datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
136
+ )
137
+ return
138
+
139
+ assert args.vae is not None, "vae checkpoint is required"
140
+
141
+ vae_path = args.vae
142
+
143
+ logger.info(f"Loading VAE model from {vae_path}")
144
+ vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
145
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
146
+ vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)
147
+
148
+ if args.clip is not None:
149
+ clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
150
+ clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
151
+ else:
152
+ clip = None
153
+
154
+ # Encode images
155
+ def encode(one_batch: list[ItemInfo]):
156
+ encode_and_save_batch(vae, clip, one_batch)
157
+
158
+ cache_latents.encode_datasets(datasets, encode, args)
159
+
160
+
161
+ def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
162
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
163
+ parser.add_argument(
164
+ "--clip",
165
+ type=str,
166
+ default=None,
167
+ help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
168
+ )
169
+ return parser
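+ # Example invocation (script name and paths are placeholders):
+ #   python wan_cache_latents.py --dataset_config dataset.toml --vae wan_2.1_vae.pth --clip clip_model.pth
+ # --clip is only required when caching for I2V training.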
170
+
171
+
172
+ if __name__ == "__main__":
173
+ parser = cache_latents.setup_parser_common()
174
+ parser = wan_setup_parser(parser)
175
+
176
+ args = parser.parse_args()
177
+ main(args)
wan_cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,107 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan
14
+
15
+ # for t5 config: all Wan2.1 models have the same config for t5
16
+ from wan.configs import wan_t2v_14B
17
+
18
+ import cache_text_encoder_outputs
19
+ import logging
20
+
21
+ from utils.model_utils import str_to_dtype
22
+ from wan.modules.t5 import T5EncoderModel
23
+
24
+ logger = logging.getLogger(__name__)
25
+ logging.basicConfig(level=logging.INFO)
26
+
27
+
28
+ def encode_and_save_batch(
29
+ text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
30
+ ):
31
+ prompts = [item.caption for item in batch]
32
+ # print(prompts)
33
+
34
+ # encode prompt
35
+ with torch.no_grad():
36
+ if accelerator is not None:
37
+ with accelerator.autocast():
38
+ context = text_encoder(prompts, device)
39
+ else:
40
+ context = text_encoder(prompts, device)
41
+
42
+ # save prompt cache
43
+ for item, ctx in zip(batch, context):
44
+ save_text_encoder_output_cache_wan(item, ctx)
45
+
46
+
47
+ def main(args):
48
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
49
+ device = torch.device(device)
50
+
51
+ # Load dataset config
52
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
53
+ logger.info(f"Load dataset config from {args.dataset_config}")
54
+ user_config = config_utils.load_user_config(args.dataset_config)
55
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
56
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
57
+
58
+ datasets = train_dataset_group.datasets
59
+
60
+ # define accelerator for fp8 inference
61
+ config = wan_t2v_14B.t2v_14B # all Wan2.1 models have the same config for t5
62
+ accelerator = None
63
+ if args.fp8_t5:
64
+ accelerator = accelerate.Accelerator(mixed_precision="bf16" if config.t5_dtype == torch.bfloat16 else "fp16")
65
+
66
+ # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
67
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
68
+
69
+ # Load T5
70
+ logger.info(f"Loading T5: {args.t5}")
71
+ text_encoder = T5EncoderModel(
72
+ text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=args.t5, fp8=args.fp8_t5
73
+ )
74
+
75
+ # Encode with T5
76
+ logger.info("Encoding with T5")
77
+
78
+ def encode_for_text_encoder(batch: list[ItemInfo]):
79
+ encode_and_save_batch(text_encoder, batch, device, accelerator)
80
+
81
+ cache_text_encoder_outputs.process_text_encoder_batches(
82
+ args.num_workers,
83
+ args.skip_existing,
84
+ args.batch_size,
85
+ datasets,
86
+ all_cache_files_for_dataset,
87
+ all_cache_paths_for_dataset,
88
+ encode_for_text_encoder,
89
+ )
90
+ del text_encoder
91
+
92
+ # remove cache files not in dataset
93
+ cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
94
+
95
+
96
+ def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
97
+ parser.add_argument("--t5", type=str, default=None, required=True, help="text encoder (T5) checkpoint path")
98
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
99
+ return parser
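+ # Example invocation (paths are placeholders):
+ #   python wan_cache_text_encoder_outputs.py --dataset_config dataset.toml --t5 t5_checkpoint.pth --batch_size 16
+ # Add --fp8_t5 to run the T5 encoder in fp8.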
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = cache_text_encoder_outputs.setup_parser_common()
104
+ parser = wan_setup_parser(parser)
105
+
106
+ args = parser.parse_args()
107
+ main(args)
wan_generate_video.py ADDED
@@ -0,0 +1,1902 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ import gc
4
+ import random
5
+ import os
6
+ import re
7
+ import time
8
+ import math
9
+ import copy
10
+ from types import ModuleType, SimpleNamespace
11
+ from typing import Tuple, Optional, List, Union, Any, Dict
12
+
13
+ import torch
14
+ import accelerate
15
+ from accelerate import Accelerator
16
+ from safetensors.torch import load_file, save_file
17
+ from safetensors import safe_open
18
+ from PIL import Image
19
+ import cv2
20
+ import numpy as np
21
+ import torchvision.transforms.functional as TF
22
+ from tqdm import tqdm
23
+
24
+ from networks import lora_wan
25
+ from utils.safetensors_utils import mem_eff_save_file, load_safetensors
26
+ from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
27
+ import wan
28
+ from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
29
+ from wan.modules.vae import WanVAE
30
+ from wan.modules.t5 import T5EncoderModel
31
+ from wan.modules.clip import CLIPModel
32
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
33
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
34
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
35
+
36
+ try:
37
+ from lycoris.kohya import create_network_from_weights
38
+ except ImportError:
39
+ pass
40
+
41
+ from utils.model_utils import str_to_dtype
42
+ from utils.device_utils import clean_memory_on_device
43
+ from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
44
+ from dataset.image_video_dataset import load_video
45
+
46
+ import logging
47
+
48
+ logger = logging.getLogger(__name__)
49
+ logging.basicConfig(level=logging.INFO)
50
+
51
+
52
+ class GenerationSettings:
53
+ def __init__(
54
+ self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
55
+ ):
56
+ self.device = device
57
+ self.cfg = cfg
58
+ self.dit_dtype = dit_dtype
59
+ self.dit_weight_dtype = dit_weight_dtype
60
+ self.vae_dtype = vae_dtype
61
+
62
+
63
+ def parse_args() -> argparse.Namespace:
64
+ """parse command line arguments"""
65
+ parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
66
+
67
+ # WAN arguments
68
+ parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
69
+ parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
70
+ parser.add_argument(
71
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
72
+ )
73
+
74
+ parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
75
+ parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
76
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
77
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
78
+ parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
79
+ parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
80
+ # LoRA
81
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
82
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
83
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
84
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
85
+ parser.add_argument(
86
+ "--save_merged_model",
87
+ type=str,
88
+ default=None,
89
+ help="Save merged model to path. If specified, no inference will be performed.",
90
+ )
91
+
92
+ # inference
93
+ parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
94
+ parser.add_argument(
95
+ "--negative_prompt",
96
+ type=str,
97
+ default=None,
98
+ help="negative prompt for generation, use default negative prompt if not specified",
99
+ )
100
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
101
+ parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
102
+ parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
103
+ parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
104
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
105
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
106
+ parser.add_argument(
107
+ "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
108
+ )
109
+ parser.add_argument(
110
+ "--guidance_scale",
111
+ type=float,
112
+ default=5.0,
113
+ help="Guidance scale for classifier free guidance. Default is 5.0.",
114
+ )
115
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
116
+ parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
117
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
118
+ parser.add_argument(
119
+ "--control_path",
120
+ type=str,
121
+ default=None,
122
+ help="path to control video for inference with controlnet. video file or directory with images",
123
+ )
124
+ parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
125
+ parser.add_argument(
126
+ "--cfg_skip_mode",
127
+ type=str,
128
+ default="none",
129
+ choices=["early", "late", "middle", "early_late", "alternate", "none"],
130
+ help="CFG skip mode. each mode skips different parts of the CFG. "
131
+ " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
132
+ )
133
+ parser.add_argument(
134
+ "--cfg_apply_ratio",
135
+ type=float,
136
+ default=None,
137
+ help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
138
+ )
139
+ parser.add_argument(
140
+ "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
141
+ )
142
+ parser.add_argument(
143
+ "--slg_scale",
144
+ type=float,
145
+ default=3.0,
146
+ help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
147
+ )
148
+ parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
149
+ parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
150
+ parser.add_argument(
151
+ "--slg_mode",
152
+ type=str,
153
+ default=None,
154
+ choices=["original", "uncond"],
155
+ help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
156
+ )
157
+
158
+ # Flow Matching
159
+ parser.add_argument(
160
+ "--flow_shift",
161
+ type=float,
162
+ default=None,
163
+ help="Shift factor for flow matching schedulers. Default depends on task.",
164
+ )
165
+
166
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
167
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
168
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
169
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
170
+ parser.add_argument(
171
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
172
+ )
173
+ parser.add_argument(
174
+ "--attn_mode",
175
+ type=str,
176
+ default="torch",
177
+ choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
178
+ help="attention mode",
179
+ )
180
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
181
+ parser.add_argument(
182
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
183
+ )
184
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
185
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
186
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
187
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
188
+ parser.add_argument(
189
+ "--compile_args",
190
+ nargs=4,
191
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
192
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
193
+ help="Torch.compile settings",
194
+ )
195
+
196
+ # New arguments for batch and interactive modes
197
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
198
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
199
+
200
+ args = parser.parse_args()
201
+
202
+ # Validate arguments
203
+ if args.from_file and args.interactive:
204
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
205
+
206
+ if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None:
207
+ raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified")
208
+
209
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
210
+ args.output_type == "images" or args.output_type == "video"
211
+ ), "latent_path is only supported for images or video output"
212
+
213
+ return args
214
+
215
+
216
+ def parse_prompt_line(line: str) -> Dict[str, Any]:
217
+ """Parse a prompt line into a dictionary of argument overrides
218
+
219
+ Args:
220
+ line: Prompt line with options
221
+
222
+ Returns:
223
+ Dict[str, Any]: Dictionary of argument overrides
224
+ """
225
+ # TODO common function with hv_train_network.line_to_prompt_dict
226
+ parts = line.split(" --")
227
+ prompt = parts[0].strip()
228
+
229
+ # Create dictionary of overrides
230
+ overrides = {"prompt": prompt}
231
+
232
+ for part in parts[1:]:
233
+ if not part.strip():
234
+ continue
235
+ option_parts = part.split(" ", 1)
236
+ option = option_parts[0].strip()
237
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
238
+
239
+ # Map options to argument names
240
+ if option == "w":
241
+ overrides["video_size_width"] = int(value)
242
+ elif option == "h":
243
+ overrides["video_size_height"] = int(value)
244
+ elif option == "f":
245
+ overrides["video_length"] = int(value)
246
+ elif option == "d":
247
+ overrides["seed"] = int(value)
248
+ elif option == "s":
249
+ overrides["infer_steps"] = int(value)
250
+ elif option == "g" or option == "l":
251
+ overrides["guidance_scale"] = float(value)
252
+ elif option == "fs":
253
+ overrides["flow_shift"] = float(value)
254
+ elif option == "i":
255
+ overrides["image_path"] = value
256
+ elif option == "cn":
257
+ overrides["control_path"] = value
258
+ elif option == "n":
259
+ overrides["negative_prompt"] = value
260
+
261
+ return overrides
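+ # Example prompt line: "a cat walking on the beach --w 832 --h 480 --f 81 --d 42 --s 20 --g 5.0 --n blurry"
+ # overrides video size, length, seed, inference steps, guidance scale and negative prompt for that prompt.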
262
+
263
+
264
+ def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
265
+ """Apply overrides to args
266
+
267
+ Args:
268
+ args: Original arguments
269
+ overrides: Dictionary of overrides
270
+
271
+ Returns:
272
+ argparse.Namespace: New arguments with overrides applied
273
+ """
274
+ args_copy = copy.deepcopy(args)
275
+
276
+ for key, value in overrides.items():
277
+ if key == "video_size_width":
278
+ args_copy.video_size[1] = value
279
+ elif key == "video_size_height":
280
+ args_copy.video_size[0] = value
281
+ else:
282
+ setattr(args_copy, key, value)
283
+
284
+ return args_copy
285
+
286
+
287
+ def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
288
+ """Return default values for each task
289
+
290
+ Args:
291
+ task: task name (t2v, t2i, i2v etc.)
292
+ size: size of the video (width, height)
293
+
294
+ Returns:
295
+ Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
296
+ """
297
+ width, height = size if size else (0, 0)
298
+
299
+ if "t2i" in task:
300
+ return 50, 5.0, 1, False
301
+ elif "i2v" in task:
302
+ flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
303
+ return 40, flow_shift, 81, True
304
+ else: # t2v or default
305
+ return 50, 5.0, 81, False
306
+
307
+
308
+ def setup_args(args: argparse.Namespace) -> argparse.Namespace:
309
+ """Validate and set default values for optional arguments
310
+
311
+ Args:
312
+ args: command line arguments
313
+
314
+ Returns:
315
+ argparse.Namespace: updated arguments
316
+ """
317
+ # Get default values for the task
318
+ infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))
319
+
320
+ # Apply default values to unset arguments
321
+ if args.infer_steps is None:
322
+ args.infer_steps = infer_steps
323
+ if args.flow_shift is None:
324
+ args.flow_shift = flow_shift
325
+ if args.video_length is None:
326
+ args.video_length = video_length
327
+
328
+ # Force video_length to 1 for t2i tasks
329
+ if "t2i" in args.task:
330
+ assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
331
+
332
+ # parse slg_layers
333
+ if args.slg_layers is not None:
334
+ args.slg_layers = list(map(int, args.slg_layers.split(",")))
335
+
336
+ return args
337
+
338
+
339
+ def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
340
+ """Validate video size and length
341
+
342
+ Args:
343
+ args: command line arguments
344
+
345
+ Returns:
346
+ Tuple[int, int, int]: (height, width, video_length)
347
+ """
348
+ height = args.video_size[0]
349
+ width = args.video_size[1]
350
+ size = f"{width}*{height}"
351
+
352
+ if size not in SUPPORTED_SIZES[args.task]:
353
+ logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")
354
+
355
+ video_length = args.video_length
356
+
357
+ if height % 8 != 0 or width % 8 != 0:
358
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
359
+
360
+ return height, width, video_length
361
+
362
+
363
+ def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
364
+ """calculate dimensions for the generation
365
+
366
+ Args:
367
+ video_size: video frame size (height, width)
368
+ video_length: number of frames in the video
369
+ config: model configuration
370
+
371
+ Returns:
372
+ Tuple[Tuple[int, int, int, int], int]:
373
+ ((channels, frames, height, width), seq_len)
374
+ """
375
+ height, width = video_size
376
+ frames = video_length
377
+
378
+ # calculate latent space dimensions
379
+ lat_f = (frames - 1) // config.vae_stride[0] + 1
380
+ lat_h = height // config.vae_stride[1]
381
+ lat_w = width // config.vae_stride[2]
382
+
383
+ # calculate sequence length
384
+ seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)
385
+
386
+ return ((16, lat_f, lat_h, lat_w), seq_len)
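+ # Worked example, assuming the standard Wan2.1 settings vae_stride=(4, 8, 8) and patch_size=(1, 2, 2):
+ # a 480x832 (HxW), 81-frame video gives lat_f=21, lat_h=60, lat_w=104 and seq_len=ceil(60*104/4*21)=32760.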
387
+
388
+
389
+ def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
390
+ """load VAE model
391
+
392
+ Args:
393
+ args: command line arguments
394
+ config: model configuration
395
+ device: device to use
396
+ dtype: data type for the model
397
+
398
+ Returns:
399
+ WanVAE: loaded VAE model
400
+ """
401
+ vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)
402
+
403
+ logger.info(f"Loading VAE model from {vae_path}")
404
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
405
+ vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
406
+ return vae
407
+
408
+
409
+ def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
410
+ """load text encoder (T5) model
411
+
412
+ Args:
413
+ args: command line arguments
414
+ config: model configuration
415
+ device: device to use
416
+
417
+ Returns:
418
+ T5EncoderModel: loaded text encoder model
419
+ """
420
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
421
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)
422
+
423
+ text_encoder = T5EncoderModel(
424
+ text_len=config.text_len,
425
+ dtype=config.t5_dtype,
426
+ device=device,
427
+ checkpoint_path=checkpoint_path,
428
+ tokenizer_path=tokenizer_path,
429
+ weight_path=args.t5,
430
+ fp8=args.fp8_t5,
431
+ )
432
+
433
+ return text_encoder
434
+
435
+
436
+ def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
437
+ """load CLIP model (for I2V only)
438
+
439
+ Args:
440
+ args: command line arguments
441
+ config: model configuration
442
+ device: device to use
443
+
444
+ Returns:
445
+ CLIPModel: loaded CLIP model
446
+ """
447
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
448
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)
449
+
450
+ clip = CLIPModel(
451
+ dtype=config.clip_dtype,
452
+ device=device,
453
+ checkpoint_path=checkpoint_path,
454
+ tokenizer_path=tokenizer_path,
455
+ weight_path=args.clip,
456
+ )
457
+
458
+ return clip
459
+
460
+
461
+ def load_dit_model(
462
+ args: argparse.Namespace,
463
+ config,
464
+ device: torch.device,
465
+ dit_dtype: torch.dtype,
466
+ dit_weight_dtype: Optional[torch.dtype] = None,
467
+ is_i2v: bool = False,
468
+ ) -> WanModel:
469
+ """load DiT model
470
+
471
+ Args:
472
+ args: command line arguments
473
+ config: model configuration
474
+ device: device to use
475
+ dit_dtype: data type for the model
476
+ dit_weight_dtype: data type for the model weights. None for as-is
477
+ is_i2v: I2V mode
478
+
479
+ Returns:
480
+ WanModel: loaded DiT model
481
+ """
482
+ loading_device = "cpu"
483
+ if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
484
+ loading_device = device
485
+
486
+ loading_weight_dtype = dit_weight_dtype
487
+ if args.fp8_scaled or args.lora_weight is not None:
488
+ loading_weight_dtype = dit_dtype # load as-is
489
+
490
+ # do not fp8 optimize because we will merge LoRA weights
491
+ model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)
492
+
493
+ return model
494
+
495
+
496
+ def merge_lora_weights(
497
+ lora_module: ModuleType,
498
+ model: torch.nn.Module,
499
+ args: argparse.Namespace,
500
+ device: torch.device,
501
+ converter: Optional[callable] = None,
502
+ ) -> None:
503
+ """merge LoRA weights to the model
504
+
505
+ Args:
506
+ lora_module: LoRA module, e.g. lora_wan
507
+ model: DiT model
508
+ args: command line arguments
509
+ device: device to use
510
+ converter: Optional callable to convert weights
511
+ """
512
+ if args.lora_weight is None or len(args.lora_weight) == 0:
513
+ return
514
+
515
+ for i, lora_weight in enumerate(args.lora_weight):
516
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
517
+ lora_multiplier = args.lora_multiplier[i]
518
+ else:
519
+ lora_multiplier = 1.0
520
+
521
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
522
+ weights_sd = load_file(lora_weight)
523
+ if converter is not None:
524
+ weights_sd = converter(weights_sd)
525
+
526
+ # apply include/exclude patterns
527
+ original_key_count = len(weights_sd.keys())
528
+ if args.include_patterns is not None and len(args.include_patterns) > i:
529
+ include_pattern = args.include_patterns[i]
530
+ regex_include = re.compile(include_pattern)
531
+ weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
532
+ logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
533
+ if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
534
+ original_key_count_ex = len(weights_sd.keys())
535
+ exclude_pattern = args.exclude_patterns[i]
536
+ regex_exclude = re.compile(exclude_pattern)
537
+ weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
538
+ logger.info(
539
+ f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
540
+ )
541
+ if len(weights_sd) != original_key_count:
542
+ remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
543
+ remaining_keys.sort()
544
+ logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
545
+ if len(weights_sd) == 0:
546
+ logger.warning(f"No keys left after filtering.")
547
+
548
+ if args.lycoris:
549
+ lycoris_net, _ = create_network_from_weights(
550
+ multiplier=lora_multiplier,
551
+ file=None,
552
+ weights_sd=weights_sd,
553
+ unet=model,
554
+ text_encoder=None,
555
+ vae=None,
556
+ for_inference=True,
557
+ )
558
+ lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
559
+ else:
560
+ network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
561
+ network.merge_to(None, model, weights_sd, device=device, non_blocking=True)
562
+
563
+ synchronize_device(device)
564
+ logger.info("LoRA weights loaded")
565
+
566
+ # save model here before casting to dit_weight_dtype
567
+ if args.save_merged_model:
568
+ logger.info(f"Saving merged model to {args.save_merged_model}")
569
+ mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory
570
+ logger.info("Merged model saved")
571
+
572
+
573
+ def optimize_model(
574
+ model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
575
+ ) -> None:
576
+ """optimize the model (FP8 conversion, device move etc.)
577
+
578
+ Args:
579
+ model: dit model
580
+ args: command line arguments
581
+ device: device to use
582
+ dit_dtype: dtype for the model
583
+ dit_weight_dtype: dtype for the model weights
584
+ """
585
+ if args.fp8_scaled:
586
+ # load state dict as-is and optimize to fp8
587
+ state_dict = model.state_dict()
588
+
589
+ # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
590
+ move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
591
+ state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
592
+
593
+ info = model.load_state_dict(state_dict, strict=True, assign=True)
594
+ logger.info(f"Loaded FP8 optimized weights: {info}")
595
+
596
+ if args.blocks_to_swap == 0:
597
+ model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
598
+ else:
599
+ # simple cast to dit_dtype
600
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
601
+ target_device = None
602
+
603
+ if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
604
+ logger.info(f"Convert model to {dit_weight_dtype}")
605
+ target_dtype = dit_weight_dtype
606
+
607
+ if args.blocks_to_swap == 0:
608
+ logger.info(f"Move model to device: {device}")
609
+ target_device = device
610
+
611
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
612
+
613
+ if args.compile:
614
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
615
+ logger.info(
616
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
617
+ )
618
+ torch._dynamo.config.cache_size_limit = 32
619
+ for i in range(len(model.blocks)):
620
+ model.blocks[i] = torch.compile(
621
+ model.blocks[i],
622
+ backend=compile_backend,
623
+ mode=compile_mode,
624
+ dynamic=compile_dynamic.lower() == "true",
625
+ fullgraph=compile_fullgraph.lower() == "true",
626
+ )
627
+
628
+ if args.blocks_to_swap > 0:
629
+ logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
630
+ model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
631
+ model.move_to_device_except_swap_blocks(device)
632
+ model.prepare_block_swap_before_forward()
633
+ else:
634
+ # make sure the model is on the right device
635
+ model.to(device)
636
+
637
+ model.eval().requires_grad_(False)
638
+ clean_memory_on_device(device)
639
+
640
+
641
+ def prepare_t2v_inputs(
642
+ args: argparse.Namespace,
643
+ config,
644
+ accelerator: Accelerator,
645
+ device: torch.device,
646
+ vae: Optional[WanVAE] = None,
647
+ encoded_context: Optional[Dict] = None,
648
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
649
+ """Prepare inputs for T2V
650
+
651
+ Args:
652
+ args: command line arguments
653
+ config: model configuration
654
+ accelerator: Accelerator instance
655
+ device: device to use
656
+ vae: VAE model for control video encoding
657
+ encoded_context: Pre-encoded text context
658
+
659
+ Returns:
660
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
661
+ (noise, context, context_null, (arg_c, arg_null))
662
+ """
663
+ # Prepare inputs for T2V
664
+ # calculate dimensions and sequence length
665
+ height, width = args.video_size
666
+ frames = args.video_length
667
+ (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
668
+ target_shape = (16, lat_f, lat_h, lat_w)
669
+
670
+ # configure negative prompt
671
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
672
+
673
+ # set seed
674
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
675
+ if not args.cpu_noise:
676
+ seed_g = torch.Generator(device=device)
677
+ seed_g.manual_seed(seed)
678
+ else:
679
+ # ComfyUI compatible noise
680
+ seed_g = torch.manual_seed(seed)
681
+
682
+ if encoded_context is None:
683
+ # load text encoder
684
+ text_encoder = load_text_encoder(args, config, device)
685
+ text_encoder.model.to(device)
686
+
687
+ # encode prompt
688
+ with torch.no_grad():
689
+ if args.fp8_t5:
690
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
691
+ context = text_encoder([args.prompt], device)
692
+ context_null = text_encoder([n_prompt], device)
693
+ else:
694
+ context = text_encoder([args.prompt], device)
695
+ context_null = text_encoder([n_prompt], device)
696
+
697
+ # free text encoder and clean memory
698
+ del text_encoder
699
+ clean_memory_on_device(device)
700
+ else:
701
+ # Use pre-encoded context
702
+ context = encoded_context["context"]
703
+ context_null = encoded_context["context_null"]
704
+
705
+ # Fun-Control: encode control video to latent space
706
+ if config.is_fun_control:
707
+ # TODO use same resizing as for image
708
+ logger.info(f"Encoding control video to latent space")
709
+ # C, F, H, W
710
+ control_video = load_control_video(args.control_path, frames, height, width).to(device)
711
+ vae.to_device(device)
712
+ with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
713
+ control_latent = vae.encode([control_video])[0]
714
+ y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent
715
+ vae.to_device("cpu")
716
+ else:
717
+ y = None
718
+
719
+ # generate noise
720
+ noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
721
+ noise = noise.to(device)
722
+
723
+ # prepare model input arguments
724
+ arg_c = {"context": context, "seq_len": seq_len}
725
+ arg_null = {"context": context_null, "seq_len": seq_len}
726
+ if y is not None:
727
+ arg_c["y"] = [y]
728
+ arg_null["y"] = [y]
729
+
730
+ return noise, context, context_null, (arg_c, arg_null)
731
+
732
+
733
+ def prepare_i2v_inputs(
734
+ args: argparse.Namespace,
735
+ config,
736
+ accelerator: Accelerator,
737
+ device: torch.device,
738
+ vae: WanVAE,
739
+ encoded_context: Optional[Dict] = None,
740
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
741
+ """Prepare inputs for I2V
742
+
743
+ Args:
744
+ args: command line arguments
745
+ config: model configuration
746
+ accelerator: Accelerator instance
747
+ device: device to use
748
+ vae: VAE model, used for image encoding
749
+ encoded_context: Pre-encoded text context
750
+
751
+ Returns:
752
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
753
+ (noise, context, context_null, y, (arg_c, arg_null))
754
+ """
755
+ # get video dimensions
756
+ height, width = args.video_size
757
+ frames = args.video_length
758
+ max_area = width * height
759
+
760
+ # load image
761
+ img = Image.open(args.image_path).convert("RGB")
762
+
763
+ # convert to numpy
764
+ img_cv2 = np.array(img) # PIL to numpy
765
+
766
+ # convert to tensor (-1 to 1)
767
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
768
+
769
+ # end frame image
770
+ if args.end_image_path is not None:
771
+ end_img = Image.open(args.end_image_path).convert("RGB")
772
+ end_img_cv2 = np.array(end_img) # PIL to numpy
773
+ else:
774
+ end_img = None
775
+ end_img_cv2 = None
776
+ has_end_image = end_img is not None
777
+
778
+ # calculate latent dimensions: keep aspect ratio
779
+ height, width = img_tensor.shape[1:]
780
+ aspect_ratio = height / width
781
+ lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
782
+ lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
783
+ height = lat_h * config.vae_stride[1]
784
+ width = lat_w * config.vae_stride[2]
785
+ lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames
786
+ max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])
787
+
788
+ # set seed
789
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
790
+ if not args.cpu_noise:
791
+ seed_g = torch.Generator(device=device)
792
+ seed_g.manual_seed(seed)
793
+ else:
794
+ # ComfyUI compatible noise
795
+ seed_g = torch.manual_seed(seed)
796
+
797
+ # generate noise
798
+ noise = torch.randn(
799
+ 16,
800
+ lat_f + (1 if has_end_image else 0),
801
+ lat_h,
802
+ lat_w,
803
+ dtype=torch.float32,
804
+ generator=seed_g,
805
+ device=device if not args.cpu_noise else "cpu",
806
+ )
807
+ noise = noise.to(device)
808
+
809
+ # configure negative prompt
810
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
811
+
812
+ if encoded_context is None:
813
+ # load text encoder
814
+ text_encoder = load_text_encoder(args, config, device)
815
+ text_encoder.model.to(device)
816
+
817
+ # encode prompt
818
+ with torch.no_grad():
819
+ if args.fp8_t5:
820
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
821
+ context = text_encoder([args.prompt], device)
822
+ context_null = text_encoder([n_prompt], device)
823
+ else:
824
+ context = text_encoder([args.prompt], device)
825
+ context_null = text_encoder([n_prompt], device)
826
+
827
+ # free text encoder and clean memory
828
+ del text_encoder
829
+ clean_memory_on_device(device)
830
+
831
+ # load CLIP model
832
+ clip = load_clip_model(args, config, device)
833
+ clip.model.to(device)
834
+
835
+ # encode image to CLIP context
836
+ logger.info(f"Encoding image to CLIP context")
837
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
838
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
839
+ logger.info(f"Encoding complete")
840
+
841
+ # free CLIP model and clean memory
842
+ del clip
843
+ clean_memory_on_device(device)
844
+ else:
845
+ # Use pre-encoded context
846
+ context = encoded_context["context"]
847
+ context_null = encoded_context["context_null"]
848
+ clip_context = encoded_context["clip_context"]
849
+
850
+ # encode image to latent space with VAE
851
+ logger.info(f"Encoding image to latent space")
852
+ vae.to_device(device)
853
+
854
+ # resize image
855
+ interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
856
+ img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
857
+ img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
858
+ img_resized = img_resized.unsqueeze(1) # CFHW
859
+
860
+ if has_end_image:
861
+ interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[1] else cv2.INTER_CUBIC
862
+ end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
863
+ end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
864
+ end_img_resized = end_img_resized.unsqueeze(1) # CFHW
865
+
866
+ # create mask for the first frame
867
+ msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
868
+ msk[:, 0] = 1
869
+ if has_end_image:
870
+ msk[:, -1] = 1
871
+
872
+ # encode image to latent space
873
+ with accelerator.autocast(), torch.no_grad():
874
+ # padding to match the required number of frames
875
+ padding_frames = frames - 1 # the first frame is image
876
+ img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
877
+ y = vae.encode([img_resized])[0]
878
+
879
+ if has_end_image:
880
+ y_end = vae.encode([end_img_resized])[0]
881
+ y = torch.concat([y, y_end], dim=1) # add end frame
882
+
883
+ y = torch.concat([msk, y])
884
+ logger.info(f"Encoding complete")
885
+
886
+ # Fun-Control: encode control video to latent space
887
+ if config.is_fun_control:
888
+ # TODO use same resizing as for image
889
+ logger.info(f"Encoding control video to latent space")
890
+ # C, F, H, W
891
+ control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
892
+ with accelerator.autocast(), torch.no_grad():
893
+ control_latent = vae.encode([control_video])[0]
894
+ y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it
895
+ if has_end_image:
896
+ y[:, 1:-1] = 0 # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work
897
+ else:
898
+ y[:, 1:] = 0 # remove image latent except first frame
899
+ y = torch.concat([control_latent, y], dim=0) # add control video latent
900
+
901
+ # prepare model input arguments
902
+ arg_c = {
903
+ "context": [context[0]],
904
+ "clip_fea": clip_context,
905
+ "seq_len": max_seq_len,
906
+ "y": [y],
907
+ }
908
+
909
+ arg_null = {
910
+ "context": context_null,
911
+ "clip_fea": clip_context,
912
+ "seq_len": max_seq_len,
913
+ "y": [y],
914
+ }
915
+
916
+ vae.to_device("cpu") # move VAE to CPU to save memory
917
+ clean_memory_on_device(device)
918
+
919
+ return noise, context, context_null, y, (arg_c, arg_null)
920
+
921
+
922
+ def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
923
+ """load control video to latent space
924
+
925
+ Args:
926
+ control_path: path to control video
927
+ frames: number of frames in the video
928
+ height: height of the video
929
+ width: width of the video
930
+
931
+ Returns:
932
+ torch.Tensor: control video latent, CFHW
933
+ """
934
+ logger.info(f"Load control video from {control_path}")
935
+ video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames
936
+ if len(video) < frames:
937
+ raise ValueError(f"Video length is less than {frames}")
938
+ # video = np.stack(video, axis=0) # F, H, W, C
939
+ video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1
940
+ video = video.permute(1, 0, 2, 3) # C, F, H, W
941
+ return video
942
+
943
+
944
+ def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
945
+ """setup scheduler for sampling
946
+
947
+ Args:
948
+ args: command line arguments
949
+ config: model configuration
950
+ device: device to use
951
+
952
+ Returns:
953
+ Tuple[Any, torch.Tensor]: (scheduler, timesteps)
954
+ """
955
+ if args.sample_solver == "unipc":
956
+ scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
957
+ scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
958
+ timesteps = scheduler.timesteps
959
+ elif args.sample_solver == "dpm++":
960
+ scheduler = FlowDPMSolverMultistepScheduler(
961
+ num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
962
+ )
963
+ sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
964
+ timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
965
+ elif args.sample_solver == "vanilla":
966
+ scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
967
+ scheduler.set_timesteps(args.infer_steps, device=device)
968
+ timesteps = scheduler.timesteps
969
+
970
+ # FlowMatchDiscreteScheduler does not support generator argument in step method
971
+ org_step = scheduler.step
972
+
973
+ def step_wrapper(
974
+ model_output: torch.Tensor,
975
+ timestep: Union[int, torch.Tensor],
976
+ sample: torch.Tensor,
977
+ return_dict: bool = True,
978
+ generator=None,
979
+ ):
980
+ return org_step(model_output, timestep, sample, return_dict=return_dict)
981
+
982
+ scheduler.step = step_wrapper
983
+ else:
984
+ raise NotImplementedError("Unsupported solver.")
985
+
986
+ return scheduler, timesteps
987
+
988
+
989
+ def run_sampling(
990
+ model: WanModel,
991
+ noise: torch.Tensor,
992
+ scheduler: Any,
993
+ timesteps: torch.Tensor,
994
+ args: argparse.Namespace,
995
+ inputs: Tuple[dict, dict],
996
+ device: torch.device,
997
+ seed_g: torch.Generator,
998
+ accelerator: Accelerator,
999
+ is_i2v: bool = False,
1000
+ use_cpu_offload: bool = True,
1001
+ ) -> torch.Tensor:
1002
+ """run sampling
1003
+ Args:
1004
+ model: dit model
1005
+ noise: initial noise
1006
+ scheduler: scheduler for sampling
1007
+ timesteps: time steps for sampling
1008
+ args: command line arguments
1009
+ inputs: model input (arg_c, arg_null)
1010
+ device: device to use
1011
+ seed_g: random generator
1012
+ accelerator: Accelerator instance
1013
+ is_i2v: I2V mode (False means T2V mode)
1014
+ use_cpu_offload: Whether to offload tensors to CPU during processing
1015
+ Returns:
1016
+ torch.Tensor: generated latent
1017
+ """
1018
+ arg_c, arg_null = inputs
1019
+
1020
+ latent = noise
1021
+ latent_storage_device = device if not use_cpu_offload else "cpu"
1022
+ latent = latent.to(latent_storage_device)
1023
+
1024
+ # cfg skip
1025
+ apply_cfg_array = []
1026
+ num_timesteps = len(timesteps)
1027
+
1028
+ if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
1029
+ # Calculate thresholds based on cfg_apply_ratio
1030
+ apply_steps = int(num_timesteps * args.cfg_apply_ratio)
1031
+
1032
+ if args.cfg_skip_mode == "early":
1033
+ # Skip CFG in early steps, apply in late steps
1034
+ start_index = num_timesteps - apply_steps
1035
+ end_index = num_timesteps
1036
+ elif args.cfg_skip_mode == "late":
1037
+ # Skip CFG in late steps, apply in early steps
1038
+ start_index = 0
1039
+ end_index = apply_steps
1040
+ elif args.cfg_skip_mode == "early_late":
1041
+ # Skip CFG in early and late steps, apply in middle steps
1042
+ start_index = (num_timesteps - apply_steps) // 2
1043
+ end_index = start_index + apply_steps
1044
+ elif args.cfg_skip_mode == "middle":
1045
+ # Skip CFG in middle steps, apply in early and late steps
1046
+ skip_steps = num_timesteps - apply_steps
1047
+ middle_start = (num_timesteps - skip_steps) // 2
1048
+ middle_end = middle_start + skip_steps
1049
+
1050
+ w = 0.0
1051
+ for step_idx in range(num_timesteps):
1052
+ if args.cfg_skip_mode == "alternate":
1053
+ # accumulate w and apply CFG when w >= 1.0
1054
+ w += args.cfg_apply_ratio
1055
+ apply = w >= 1.0
1056
+ if apply:
1057
+ w -= 1.0
1058
+ elif args.cfg_skip_mode == "middle":
1059
+ # Skip CFG in early and late steps, apply in middle steps
1060
+ apply = step_idx < middle_start or step_idx >= middle_end
1061
+ else:
1062
+ # Apply CFG on some steps based on ratio
1063
+ apply = step_idx >= start_index and step_idx < end_index
1064
+
1065
+ apply_cfg_array.append(apply)
1066
+
1067
+ pattern = ["A" if apply else "S" for apply in apply_cfg_array]
1068
+ pattern = "".join(pattern)
1069
+ logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
1070
+ else:
1071
+ # Apply CFG on all steps
1072
+ apply_cfg_array = [True] * num_timesteps
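+ # Example: with 10 steps, --cfg_apply_ratio 0.5 and --cfg_skip_mode late, the logged pattern is
+ # "AAAAASSSSS" (CFG applied on the first five steps, skipped on the last five).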
1073
+
1074
+ # SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
1075
+ slg_start_step = int(args.slg_start * num_timesteps)
1076
+ slg_end_step = int(args.slg_end * num_timesteps)
1077
+
1078
+ for i, t in enumerate(tqdm(timesteps)):
1079
+ # latent is on CPU if use_cpu_offload is True
1080
+ latent_model_input = [latent.to(device)]
1081
+ timestep = torch.stack([t]).to(device)
1082
+
1083
+ with accelerator.autocast(), torch.no_grad():
1084
+ noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)
1085
+
1086
+ apply_cfg = apply_cfg_array[i] # apply CFG or not
1087
+ if apply_cfg:
1088
+ apply_slg = i >= slg_start_step and i < slg_end_step
1089
+ # print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
1090
+ if args.slg_mode == "original" and apply_slg:
1091
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
1092
+
1093
+ # apply guidance
1094
+ # SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
1095
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1096
+
1097
+ # calculate skip layer out
1098
+ skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
1099
+ latent_storage_device
1100
+ )
1101
+
1102
+ # apply skip layer guidance
1103
+ # SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
1104
+ noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
1105
+ elif args.slg_mode == "uncond" and apply_slg:
1106
+ # noise_pred_uncond is skip layer out
1107
+ noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
1108
+ latent_storage_device
1109
+ )
1110
+
1111
+ # apply guidance
1112
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1113
+
1114
+ else:
1115
+ # normal guidance
1116
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
1117
+
1118
+ # apply guidance
1119
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1120
+ else:
1121
+ noise_pred = noise_pred_cond
1122
+
1123
+ # step
1124
+ latent_input = latent.unsqueeze(0)
1125
+ temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]
1126
+
1127
+ # update latent
1128
+ latent = temp_x0.squeeze(0)
1129
+
1130
+ return latent
1131
+
1132
+
1133
+ def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
1134
+ """main function for generation
1135
+
1136
+ Args:
1137
+ args: command line arguments
1138
+ shared_models: dictionary containing pre-loaded models and encoded data
1139
+
1140
+ Returns:
1141
+ torch.Tensor: generated latent
1142
+ """
1143
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1144
+ gen_settings.device,
1145
+ gen_settings.cfg,
1146
+ gen_settings.dit_dtype,
1147
+ gen_settings.dit_weight_dtype,
1148
+ gen_settings.vae_dtype,
1149
+ )
1150
+
1151
+ # prepare accelerator
1152
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
1153
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
1154
+
1155
+ # I2V or T2V
1156
+ is_i2v = "i2v" in args.task
1157
+
1158
+ # prepare seed
1159
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
1160
+ args.seed = seed # set seed to args for saving
1161
+
1162
+ # Check if we have shared models
1163
+ if shared_models is not None:
1164
+ # Use shared models and encoded data
1165
+ vae = shared_models.get("vae")
1166
+ model = shared_models.get("model")
1167
+ encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
1168
+
1169
+ # prepare inputs
1170
+ if is_i2v:
1171
+ # I2V
1172
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
1173
+ else:
1174
+ # T2V
1175
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
1176
+ else:
1177
+ # prepare inputs without shared models
1178
+ if is_i2v:
1179
+ # I2V: need text encoder, VAE and CLIP
1180
+ vae = load_vae(args, cfg, device, vae_dtype)
1181
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
1182
+ # vae is on CPU after prepare_i2v_inputs
1183
+ else:
1184
+ # T2V: need text encoder
1185
+ vae = None
1186
+ if cfg.is_fun_control:
1187
+ # Fun-Control: need VAE for encoding control video
1188
+ vae = load_vae(args, cfg, device, vae_dtype)
1189
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)
1190
+
1191
+ # load DiT model
1192
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1193
+
1194
+ # merge LoRA weights
1195
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1196
+ merge_lora_weights(lora_wan, model, args, device)
1197
+
1198
+ # if we only want to save the model, we can skip the rest
1199
+ if args.save_merged_model:
1200
+ return None
1201
+
1202
+ # optimize model: fp8 conversion, block swap etc.
1203
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1204
+
1205
+ # setup scheduler
1206
+ scheduler, timesteps = setup_scheduler(args, cfg, device)
1207
+
1208
+ # set random generator
1209
+ seed_g = torch.Generator(device=device)
1210
+ seed_g.manual_seed(seed)
1211
+
1212
+ # run sampling
1213
+ latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)
1214
+
1215
+ # Only clean up the models if they were created within this function (i.e. not shared)
1216
+ if shared_models is None:
1217
+ # free memory
1218
+ del model
1219
+ del scheduler
1220
+ synchronize_device(device)
1221
+
1222
+ # wait 5 seconds for block swap to finish
1223
+ logger.info("Waiting for 5 seconds to finish block swap")
1224
+ time.sleep(5)
1225
+
1226
+ gc.collect()
1227
+ clean_memory_on_device(device)
1228
+
1229
+ # save VAE model for decoding
1230
+ if vae is None:
1231
+ args._vae = None
1232
+ else:
1233
+ args._vae = vae
1234
+
1235
+ return latent
1236
+
1237
+
1238
+ def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
1239
+ """decode latent
1240
+
1241
+ Args:
1242
+ latent: latent tensor
1243
+ args: command line arguments
1244
+ cfg: model configuration
1245
+
1246
+ Returns:
1247
+ torch.Tensor: decoded video or image
1248
+ """
1249
+ device = torch.device(args.device)
1250
+
1251
+ # load VAE model or use the one from the generation
1252
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
1253
+ if hasattr(args, "_vae") and args._vae is not None:
1254
+ vae = args._vae
1255
+ else:
1256
+ vae = load_vae(args, cfg, device, vae_dtype)
1257
+
1258
+ vae.to_device(device)
1259
+
1260
+ logger.info(f"Decoding video from latents: {latent.shape}")
1261
+ x0 = latent.to(device)
1262
+
1263
+ with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
1264
+ videos = vae.decode(x0)
1265
+
1266
+ # some tail frames may be corrupted when an end frame is used, so an option is provided to trim them
1267
+ if args.trim_tail_frames:
1268
+ videos[0] = videos[0][:, : -args.trim_tail_frames]
1269
+
1270
+ logger.info(f"Decoding complete")
1271
+ video = videos[0]
1272
+ del videos
1273
+ video = video.to(torch.float32).cpu()
1274
+
1275
+ return video
1276
+
1277
+
1278
+ def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
1279
+ """Save latent to file
1280
+
1281
+ Args:
1282
+ latent: latent tensor
1283
+ args: command line arguments
1284
+ height: height of frame
1285
+ width: width of frame
1286
+
1287
+ Returns:
1288
+ str: Path to saved latent file
1289
+ """
1290
+ save_path = args.save_path
1291
+ os.makedirs(save_path, exist_ok=True)
1292
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1293
+
1294
+ seed = args.seed
1295
+ video_length = args.video_length
1296
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
1297
+
1298
+ if args.no_metadata:
1299
+ metadata = None
1300
+ else:
1301
+ metadata = {
1302
+ "seeds": f"{seed}",
1303
+ "prompt": f"{args.prompt}",
1304
+ "height": f"{height}",
1305
+ "width": f"{width}",
1306
+ "video_length": f"{video_length}",
1307
+ "infer_steps": f"{args.infer_steps}",
1308
+ "guidance_scale": f"{args.guidance_scale}",
1309
+ }
1310
+ if args.negative_prompt is not None:
1311
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
1312
+
1313
+ sd = {"latent": latent}
1314
+ save_file(sd, latent_path, metadata=metadata)
1315
+ logger.info(f"Latent saved to: {latent_path}")
1316
+
1317
+ return latent_path
1318
+
1319
+
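A latent saved this way can be loaded back together with its metadata, which is exactly what the latent-decode path in main() below relies on. A minimal sketch (the path is illustrative):

from safetensors import safe_open
from safetensors.torch import load_file

latent_path = "outputs/20250101-000000_42_latent.safetensors"  # illustrative path
latent = load_file(latent_path)["latent"]
with safe_open(latent_path, framework="pt") as f:
    metadata = f.metadata() or {}  # e.g. {"seeds": "42", "height": ..., "width": ..., "video_length": ...}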
1320
+ def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
1321
+ """Save video to file
1322
+
1323
+ Args:
1324
+ video: Video tensor
1325
+ args: command line arguments
1326
+ original_base_name: Original base name (if latents are loaded from files)
1327
+
1328
+ Returns:
1329
+ str: Path to saved video file
1330
+ """
1331
+ save_path = args.save_path
1332
+ os.makedirs(save_path, exist_ok=True)
1333
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1334
+
1335
+ seed = args.seed
1336
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1337
+ video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"
1338
+
1339
+ video = video.unsqueeze(0)
1340
+ save_videos_grid(video, video_path, fps=args.fps, rescale=True)
1341
+ logger.info(f"Video saved to: {video_path}")
1342
+
1343
+ return video_path
1344
+
1345
+
1346
+ def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
1347
+ """Save images to directory
1348
+
1349
+ Args:
1350
+ sample: Video tensor
1351
+ args: command line arguments
1352
+ original_base_name: Original base name (if latents are loaded from files)
1353
+
1354
+ Returns:
1355
+ str: Path to saved images directory
1356
+ """
1357
+ save_path = args.save_path
1358
+ os.makedirs(save_path, exist_ok=True)
1359
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1360
+
1361
+ seed = args.seed
1362
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1363
+ image_name = f"{time_flag}_{seed}{original_name}"
1364
+ sample = sample.unsqueeze(0)
1365
+ save_images_grid(sample, save_path, image_name, rescale=True)
1366
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
1367
+
1368
+ return f"{save_path}/{image_name}"
1369
+
1370
+
1371
+ def save_output(
1372
+ latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
1373
+ ) -> None:
1374
+ """save output
1375
+
1376
+ Args:
1377
+ latent: latent tensor
1378
+ args: command line arguments
1379
+ cfg: model configuration
1380
+ height: height of frame
1381
+ width: width of frame
1382
+ original_base_names: original base names (if latents are loaded from files)
1383
+ """
1384
+ if args.output_type == "latent" or args.output_type == "both":
1385
+ # save latent
1386
+ save_latent(latent, args, height, width)
1387
+
1388
+ if args.output_type == "video" or args.output_type == "both":
1389
+ # save video
1390
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
1391
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
1392
+ save_video(sample, args, original_base_name)
1393
+
1394
+ elif args.output_type == "images":
1395
+ # save images
1396
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
1397
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
1398
+ save_images(sample, args, original_base_name)
1399
+
1400
+
1401
+ def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
1402
+ """Process multiple prompts for batch mode
1403
+
1404
+ Args:
1405
+ prompt_lines: List of prompt lines
1406
+ base_args: Base command line arguments
1407
+
1408
+ Returns:
1409
+ List[Dict]: List of prompt data dictionaries
1410
+ """
1411
+ prompts_data = []
1412
+
1413
+ for line in prompt_lines:
1414
+ line = line.strip()
1415
+ if not line or line.startswith("#"): # Skip empty lines and comments
1416
+ continue
1417
+
1418
+ # Parse prompt line and create override dictionary
1419
+ prompt_data = parse_prompt_line(line)
1420
+ logger.info(f"Parsed prompt data: {prompt_data}")
1421
+ prompts_data.append(prompt_data)
1422
+
1423
+ return prompts_data
1424
+
1425
+
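Batch mode consumes one prompt per line; blank lines and lines starting with '#' are skipped, and any per-line overrides are parsed by parse_prompt_line (defined elsewhere in this file). A minimal sketch of driving it directly, assuming args is the parsed namespace and prompts.txt is a plain text file:

# prompts.txt might contain, for example:
#   # comments and blank lines are ignored
#   A corgi running on the beach at sunset
#   A timelapse of clouds over a mountain ridge
with open("prompts.txt", "r", encoding="utf-8") as f:
    prompts_data = preprocess_prompts_for_batch(f.readlines(), args)
process_batch_prompts(prompts_data, args)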
1426
+ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
1427
+ """Process multiple prompts with model reuse
1428
+
1429
+ Args:
1430
+ prompts_data: List of prompt data dictionaries
1431
+ args: Base command line arguments
1432
+ """
1433
+ if not prompts_data:
1434
+ logger.warning("No valid prompts found")
1435
+ return
1436
+
1437
+ # 1. Load configuration
1438
+ gen_settings = get_generation_settings(args)
1439
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1440
+ gen_settings.device,
1441
+ gen_settings.cfg,
1442
+ gen_settings.dit_dtype,
1443
+ gen_settings.dit_weight_dtype,
1444
+ gen_settings.vae_dtype,
1445
+ )
1446
+ is_i2v = "i2v" in args.task
1447
+
1448
+ # 2. Encode all prompts
1449
+ logger.info("Loading text encoder to encode all prompts")
1450
+ text_encoder = load_text_encoder(args, cfg, device)
1451
+ text_encoder.model.to(device)
1452
+
1453
+ encoded_contexts = {}
1454
+
1455
+ with torch.no_grad():
1456
+ for prompt_data in prompts_data:
1457
+ prompt = prompt_data["prompt"]
1458
+ prompt_args = apply_overrides(args, prompt_data)
1459
+ n_prompt = prompt_data.get(
1460
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
1461
+ )
1462
+
1463
+ if args.fp8_t5:
1464
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
1465
+ context = text_encoder([prompt], device)
1466
+ context_null = text_encoder([n_prompt], device)
1467
+ else:
1468
+ context = text_encoder([prompt], device)
1469
+ context_null = text_encoder([n_prompt], device)
1470
+
1471
+ encoded_contexts[prompt] = {"context": context, "context_null": context_null}
1472
+
1473
+ # Free text encoder and clean memory
1474
+ del text_encoder
1475
+ clean_memory_on_device(device)
1476
+
1477
+ # 3. Process I2V additional encodings if needed
1478
+ vae = None
1479
+ if is_i2v:
1480
+ logger.info("Loading VAE and CLIP for I2V preprocessing")
1481
+ vae = load_vae(args, cfg, device, vae_dtype)
1482
+ vae.to_device(device)
1483
+
1484
+ clip = load_clip_model(args, cfg, device)
1485
+ clip.model.to(device)
1486
+
1487
+ # Process each image and encode with CLIP
1488
+ for prompt_data in prompts_data:
1489
+ if "image_path" not in prompt_data:
1490
+ continue
1491
+
1492
+ prompt_args = apply_overrides(args, prompt_data)
1493
+ if not os.path.exists(prompt_args.image_path):
1494
+ logger.warning(f"Image path not found: {prompt_args.image_path}")
1495
+ continue
1496
+
1497
+ # Load and encode image with CLIP
1498
+ img = Image.open(prompt_args.image_path).convert("RGB")
1499
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
1500
+
1501
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
1502
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
1503
+
1504
+ encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context
1505
+
1506
+ # Free CLIP and clean memory
1507
+ del clip
1508
+ clean_memory_on_device(device)
1509
+
1510
+ # Keep VAE in CPU memory for later use
1511
+ vae.to_device("cpu")
1512
+ elif cfg.is_fun_control:
1513
+ # For Fun-Control, we need VAE but keep it on CPU
1514
+ vae = load_vae(args, cfg, device, vae_dtype)
1515
+ vae.to_device("cpu")
1516
+
1517
+ # 4. Load DiT model
1518
+ logger.info("Loading DiT model")
1519
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1520
+
1521
+ # 5. Merge LoRA weights if needed
1522
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1523
+ merge_lora_weights(lora_wan, model, args, device)
1524
+ if args.save_merged_model:
1525
+ logger.info("Model merged and saved. Exiting.")
1526
+ return
1527
+
1528
+ # 6. Optimize model
1529
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1530
+
1531
+ # Create shared models dict for generate function
1532
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}
1533
+
1534
+ # 7. Generate for each prompt
1535
+ all_latents = []
1536
+ all_prompt_args = []
1537
+
1538
+ for i, prompt_data in enumerate(prompts_data):
1539
+ logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")
1540
+
1541
+ # Apply overrides for this prompt
1542
+ prompt_args = apply_overrides(args, prompt_data)
1543
+
1544
+ # Generate latent
1545
+ latent = generate(prompt_args, gen_settings, shared_models)
1546
+
1547
+ # Save latent if needed
1548
+ height, width, _ = check_inputs(prompt_args)
1549
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
1550
+ save_latent(latent, prompt_args, height, width)
1551
+
1552
+ all_latents.append(latent)
1553
+ all_prompt_args.append(prompt_args)
1554
+
1555
+ # 8. Free DiT model
1556
+ del model
1557
+ clean_memory_on_device(device)
1558
+ synchronize_device(device)
1559
+
1560
+ # wait 5 seconds for block swap to finish
1561
+ logger.info("Waiting for 5 seconds to finish block swap")
1562
+ time.sleep(5)
1563
+
1564
+ gc.collect()
1565
+ clean_memory_on_device(device)
1566
+
1567
+ # 9. Decode latents if needed
1568
+ if args.output_type != "latent":
1569
+ logger.info("Decoding latents to videos/images")
1570
+
1571
+ if vae is None:
1572
+ vae = load_vae(args, cfg, device, vae_dtype)
1573
+
1574
+ vae.to_device(device)
1575
+
1576
+ for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
1577
+ logger.info(f"Decoding output {i+1}/{len(all_latents)}")
1578
+
1579
+ # Decode latent
1580
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
1581
+
1582
+ # Save as video or images
1583
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
1584
+ save_video(video, prompt_args)
1585
+ elif prompt_args.output_type == "images":
1586
+ save_images(video, prompt_args)
1587
+
1588
+ # Free VAE
1589
+ del vae
1590
+
1591
+ clean_memory_on_device(device)
1592
+ gc.collect()
1593
+
1594
+
1595
+ def process_interactive(args: argparse.Namespace) -> None:
1596
+ """Process prompts in interactive mode
1597
+
1598
+ Args:
1599
+ args: Base command line arguments
1600
+ """
1601
+ gen_settings = get_generation_settings(args)
1602
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1603
+ gen_settings.device,
1604
+ gen_settings.cfg,
1605
+ gen_settings.dit_dtype,
1606
+ gen_settings.dit_weight_dtype,
1607
+ gen_settings.vae_dtype,
1608
+ )
1609
+ is_i2v = "i2v" in args.task
1610
+
1611
+ # Initialize models to None
1612
+ text_encoder = None
1613
+ vae = None
1614
+ model = None
1615
+ clip = None
1616
+
1617
+ print("Interactive mode. Enter prompts (Ctrl+D to exit):")
1618
+
1619
+ try:
1620
+ while True:
1621
+ try:
1622
+ line = input("> ")
1623
+ if not line.strip():
1624
+ continue
1625
+
1626
+ # Parse prompt
1627
+ prompt_data = parse_prompt_line(line)
1628
+ prompt_args = apply_overrides(args, prompt_data)
1629
+
1630
+ # Ensure we have all the models we need
1631
+
1632
+ # 1. Load text encoder if not already loaded
1633
+ if text_encoder is None:
1634
+ logger.info("Loading text encoder")
1635
+ text_encoder = load_text_encoder(args, cfg, device)
1636
+
1637
+ text_encoder.model.to(device)
1638
+
1639
+ # Encode prompt
1640
+ n_prompt = prompt_data.get(
1641
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
1642
+ )
1643
+
1644
+ with torch.no_grad():
1645
+ if args.fp8_t5:
1646
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
1647
+ context = text_encoder([prompt_data["prompt"]], device)
1648
+ context_null = text_encoder([n_prompt], device)
1649
+ else:
1650
+ context = text_encoder([prompt_data["prompt"]], device)
1651
+ context_null = text_encoder([n_prompt], device)
1652
+
1653
+ encoded_context = {"context": context, "context_null": context_null}
1654
+
1655
+ # Move text encoder to CPU after use
1656
+ text_encoder.model.to("cpu")
1657
+
1658
+ # 2. For I2V, we need CLIP and VAE
1659
+ if is_i2v:
1660
+ if clip is None:
1661
+ logger.info("Loading CLIP model")
1662
+ clip = load_clip_model(args, cfg, device)
1663
+
1664
+ clip.model.to(device)
1665
+
1666
+ # Encode image with CLIP if there's an image path
1667
+ if prompt_args.image_path and os.path.exists(prompt_args.image_path):
1668
+ img = Image.open(prompt_args.image_path).convert("RGB")
1669
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
1670
+
1671
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
1672
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
1673
+
1674
+ encoded_context["clip_context"] = clip_context
1675
+
1676
+ # Move CLIP to CPU after use
1677
+ clip.model.to("cpu")
1678
+
1679
+ # Load VAE if needed
1680
+ if vae is None:
1681
+ logger.info("Loading VAE model")
1682
+ vae = load_vae(args, cfg, device, vae_dtype)
1683
+ elif cfg.is_fun_control and vae is None:
1684
+ # For Fun-Control, we need VAE
1685
+ logger.info("Loading VAE model for Fun-Control")
1686
+ vae = load_vae(args, cfg, device, vae_dtype)
1687
+
1688
+ # 3. Load DiT model if not already loaded
1689
+ if model is None:
1690
+ logger.info("Loading DiT model")
1691
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1692
+
1693
+ # Merge LoRA weights if needed
1694
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1695
+ merge_lora_weights(lora_wan, model, args, device)
1696
+
1697
+ # Optimize model
1698
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1699
+ else:
1700
+ # Move model to GPU if it was offloaded
1701
+ model.to(device)
1702
+
1703
+ # Create shared models dict
1704
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}
1705
+
1706
+ # Generate latent
1707
+ latent = generate(prompt_args, gen_settings, shared_models)
1708
+
1709
+ # Move model to CPU after generation
1710
+ model.to("cpu")
1711
+
1712
+ # Save latent if needed
1713
+ height, width, _ = check_inputs(prompt_args)
1714
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
1715
+ save_latent(latent, prompt_args, height, width)
1716
+
1717
+ # Decode and save output
1718
+ if prompt_args.output_type != "latent":
1719
+ if vae is None:
1720
+ vae = load_vae(args, cfg, device, vae_dtype)
1721
+
1722
+ vae.to_device(device)
1723
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
1724
+
1725
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
1726
+ save_video(video, prompt_args)
1727
+ elif prompt_args.output_type == "images":
1728
+ save_images(video, prompt_args)
1729
+
1730
+ # Move VAE to CPU after use
1731
+ vae.to_device("cpu")
1732
+
1733
+ clean_memory_on_device(device)
1734
+
1735
+ except KeyboardInterrupt:
1736
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
1737
+ continue
1738
+
1739
+ except EOFError:
1740
+ print("\nExiting interactive mode")
1741
+
1742
+ # Clean up all models
1743
+ if text_encoder is not None:
1744
+ del text_encoder
1745
+ if clip is not None:
1746
+ del clip
1747
+ if vae is not None:
1748
+ del vae
1749
+ if model is not None:
1750
+ del model
1751
+
1752
+ clean_memory_on_device(device)
1753
+ gc.collect()
1754
+
1755
+
1756
+ def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
1757
+ device = torch.device(args.device)
1758
+
1759
+ cfg = WAN_CONFIGS[args.task]
1760
+
1761
+ # select dtype
1762
+ dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
1763
+ if dit_dtype.itemsize == 1:
1764
+ # if weight is in fp8, use bfloat16 for DiT (input/output)
1765
+ dit_dtype = torch.bfloat16
1766
+ if args.fp8_scaled:
1767
+ raise ValueError(
1768
+ "DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
1769
+ )
1770
+
1771
+ dit_weight_dtype = dit_dtype # default
1772
+ if args.fp8_scaled:
1773
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
1774
+ elif args.fp8:
1775
+ dit_weight_dtype = torch.float8_e4m3fn
1776
+
1777
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
1778
+ logger.info(
1779
+ f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
1780
+ )
1781
+
1782
+ gen_settings = GenerationSettings(
1783
+ device=device,
1784
+ cfg=cfg,
1785
+ dit_dtype=dit_dtype,
1786
+ dit_weight_dtype=dit_weight_dtype,
1787
+ vae_dtype=vae_dtype,
1788
+ )
1789
+ return gen_settings
1790
+
1791
+
1792
+ def main():
1793
+ # Parse arguments
1794
+ args = parse_args()
1795
+
1796
+ # Check if latents are provided
1797
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
1798
+
1799
+ # Set device
1800
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
1801
+ device = torch.device(device)
1802
+ logger.info(f"Using device: {device}")
1803
+ args.device = device
1804
+
1805
+ if latents_mode:
1806
+ # Original latent decode mode
1807
+ cfg = WAN_CONFIGS[args.task] # any task is fine
1808
+ original_base_names = []
1809
+ latents_list = []
1810
+ seeds = []
1811
+
1812
+ assert len(args.latent_path) == 1, "Only one latent path is supported for now"
1813
+
1814
+ for latent_path in args.latent_path:
1815
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
1816
+ seed = 0
1817
+
1818
+ if os.path.splitext(latent_path)[1] != ".safetensors":
1819
+ latents = torch.load(latent_path, map_location="cpu")
1820
+ else:
1821
+ latents = load_file(latent_path)["latent"]
1822
+ with safe_open(latent_path, framework="pt") as f:
1823
+ metadata = f.metadata()
1824
+ if metadata is None:
1825
+ metadata = {}
1826
+ logger.info(f"Loaded metadata: {metadata}")
1827
+
1828
+ if "seeds" in metadata:
1829
+ seed = int(metadata["seeds"])
1830
+ if "height" in metadata and "width" in metadata:
1831
+ height = int(metadata["height"])
1832
+ width = int(metadata["width"])
1833
+ args.video_size = [height, width]
1834
+ if "video_length" in metadata:
1835
+ args.video_length = int(metadata["video_length"])
1836
+
1837
+ seeds.append(seed)
1838
+ latents_list.append(latents)
1839
+
1840
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
1841
+
1842
+ latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
1843
+
1844
+ height = latents.shape[-2]
1845
+ width = latents.shape[-1]
1846
+ height *= cfg.patch_size[1] * cfg.vae_stride[1]
1847
+ width *= cfg.patch_size[2] * cfg.vae_stride[2]
1848
+ video_length = latents.shape[1]
1849
+ video_length = (video_length - 1) * cfg.vae_stride[0] + 1
1850
+ args.seed = seeds[0]
1851
+
1852
+ # Decode and save
1853
+ save_output(latent[0], args, cfg, height, width, original_base_names)
1854
+
1855
+ elif args.from_file:
1856
+ # Batch mode from file
1857
+ args = setup_args(args)
1858
+
1859
+ # Read prompts from file
1860
+ with open(args.from_file, "r", encoding="utf-8") as f:
1861
+ prompt_lines = f.readlines()
1862
+
1863
+ # Process prompts
1864
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
1865
+ process_batch_prompts(prompts_data, args)
1866
+
1867
+ elif args.interactive:
1868
+ # Interactive mode
1869
+ args = setup_args(args)
1870
+ process_interactive(args)
1871
+
1872
+ else:
1873
+ # Single prompt mode (original behavior)
1874
+ args = setup_args(args)
1875
+ height, width, video_length = check_inputs(args)
1876
+
1877
+ logger.info(
1878
+ f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
1879
+ f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
1880
+ )
1881
+
1882
+ # Generate latent
1883
+ gen_settings = get_generation_settings(args)
1884
+ latent = generate(args, gen_settings)
1885
+
1886
+ # Make sure the model is freed from GPU memory
1887
+ gc.collect()
1888
+ clean_memory_on_device(args.device)
1889
+
1890
+ # Save latent and video
1891
+ if args.save_merged_model:
1892
+ return
1893
+
1894
+ # Add batch dimension
1895
+ latent = latent.unsqueeze(0)
1896
+ save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)
1897
+
1898
+ logger.info("Done!")
1899
+
1900
+
1901
+ if __name__ == "__main__":
1902
+ main()
wan_train_network.py ADDED
@@ -0,0 +1,444 @@
1
+ import argparse
2
+ from typing import Optional
3
+ from PIL import Image
4
+
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision.transforms.functional as TF
9
+ from tqdm import tqdm
10
+ from accelerate import Accelerator, init_empty_weights
11
+
12
+ from dataset.image_video_dataset import ARCHITECTURE_WAN, ARCHITECTURE_WAN_FULL, load_video
13
+ from hv_generate_video import resize_image_to_bucket
14
+ from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
15
+
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+ logging.basicConfig(level=logging.INFO)
20
+
21
+ from utils import model_utils
22
+ from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
23
+ from wan.configs import WAN_CONFIGS
24
+ from wan.modules.clip import CLIPModel
25
+ from wan.modules.model import WanModel, detect_wan_sd_dtype, load_wan_model
26
+ from wan.modules.t5 import T5EncoderModel
27
+ from wan.modules.vae import WanVAE
28
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+
31
+ class WanNetworkTrainer(NetworkTrainer):
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ # region model specific
36
+
37
+ @property
38
+ def architecture(self) -> str:
39
+ return ARCHITECTURE_WAN
40
+
41
+ @property
42
+ def architecture_full_name(self) -> str:
43
+ return ARCHITECTURE_WAN_FULL
44
+
45
+ def handle_model_specific_args(self, args):
46
+ self.config = WAN_CONFIGS[args.task]
47
+ self._i2v_training = "i2v" in args.task # we cannot use config.i2v because Fun-Control T2V has i2v flag TODO refactor this
48
+ self._control_training = self.config.is_fun_control
49
+
50
+ self.dit_dtype = detect_wan_sd_dtype(args.dit)
51
+
52
+ if self.dit_dtype == torch.float16:
53
+ assert args.mixed_precision in ["fp16", "no"], "DiT weights are in fp16, mixed precision must be fp16 or no"
54
+ elif self.dit_dtype == torch.bfloat16:
55
+ assert args.mixed_precision in ["bf16", "no"], "DiT weights are in bf16, mixed precision must be bf16 or no"
56
+
57
+ if args.fp8_scaled and self.dit_dtype.itemsize == 1:
58
+ raise ValueError(
59
+ "DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
60
+ )
61
+
62
+ # dit_dtype cannot be fp8, so we select the appropriate dtype
63
+ if self.dit_dtype.itemsize == 1:
64
+ self.dit_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
65
+
66
+ args.dit_dtype = model_utils.dtype_to_str(self.dit_dtype)
67
+
68
+ self.default_guidance_scale = 1.0 # not used
69
+
70
+ def process_sample_prompts(
71
+ self,
72
+ args: argparse.Namespace,
73
+ accelerator: Accelerator,
74
+ sample_prompts: str,
75
+ ):
76
+ config = self.config
77
+ device = accelerator.device
78
+ t5_path, clip_path, fp8_t5 = args.t5, args.clip, args.fp8_t5
79
+
80
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
81
+ prompts = load_prompts(sample_prompts)
82
+
83
+ def encode_for_text_encoder(text_encoder):
84
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
85
+ # with accelerator.autocast(), torch.no_grad(): # this causes NaN if dit_dtype is fp16
86
+ t5_dtype = config.t5_dtype
87
+ with torch.amp.autocast(device_type=device.type, dtype=t5_dtype), torch.no_grad():
88
+ for prompt_dict in prompts:
89
+ if "negative_prompt" not in prompt_dict:
90
+ prompt_dict["negative_prompt"] = self.config["sample_neg_prompt"]
91
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
92
+ if p is None:
93
+ continue
94
+ if p not in sample_prompts_te_outputs:
95
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
96
+
97
+ prompt_outputs = text_encoder([p], device)
98
+ sample_prompts_te_outputs[p] = prompt_outputs
99
+
100
+ return sample_prompts_te_outputs
101
+
102
+ # Load Text Encoder 1 and encode
103
+ logger.info(f"loading T5: {t5_path}")
104
+ t5 = T5EncoderModel(text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=t5_path, fp8=fp8_t5)
105
+
106
+ logger.info("encoding with Text Encoder 1")
107
+ te_outputs_1 = encode_for_text_encoder(t5)
108
+ del t5
109
+
110
+ # load CLIP and encode image (for I2V training)
111
+ # Note: VAE encoding is done in do_inference() for I2V training, because we have VAE in the pipeline. Control video is also done in do_inference()
112
+ sample_prompts_image_embs = {}
113
+ for prompt_dict in prompts:
114
+ if prompt_dict.get("image_path", None) is not None and self.i2v_training:
115
+ sample_prompts_image_embs[prompt_dict["image_path"]] = None # this will be replaced with CLIP context
116
+
117
+ if len(sample_prompts_image_embs) > 0:
118
+ logger.info(f"loading CLIP: {clip_path}")
119
+ assert clip_path is not None, "CLIP path is required for I2V training / I2V学習にはCLIPのパスが必要です"
120
+ clip = CLIPModel(dtype=config.clip_dtype, device=device, weight_path=clip_path)
121
+ clip.model.to(device)
122
+
123
+ logger.info(f"Encoding image to CLIP context")
124
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
125
+ for image_path in sample_prompts_image_embs:
126
+ logger.info(f"Encoding image: {image_path}")
127
+ img = Image.open(image_path).convert("RGB")
128
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) # -1 to 1
129
+ clip_context = clip.visual([img[:, None, :, :]])
130
+ sample_prompts_image_embs[image_path] = clip_context
131
+
132
+ del clip
133
+ clean_memory_on_device(device)
134
+
135
+ # prepare sample parameters
136
+ sample_parameters = []
137
+ for prompt_dict in prompts:
138
+ prompt_dict_copy = prompt_dict.copy()
139
+
140
+ p = prompt_dict.get("prompt", "")
141
+ prompt_dict_copy["t5_embeds"] = te_outputs_1[p][0]
142
+
143
+ p = prompt_dict.get("negative_prompt", None)
144
+ if p is not None:
145
+ prompt_dict_copy["negative_t5_embeds"] = te_outputs_1[p][0]
146
+
147
+ p = prompt_dict.get("image_path", None)
148
+ if p is not None and self.i2v_training:
149
+ prompt_dict_copy["clip_embeds"] = sample_prompts_image_embs[p]
150
+
151
+ sample_parameters.append(prompt_dict_copy)
152
+
153
+ clean_memory_on_device(accelerator.device)
154
+
155
+ return sample_parameters
156
+
157
+ def do_inference(
158
+ self,
159
+ accelerator,
160
+ args,
161
+ sample_parameter,
162
+ vae,
163
+ dit_dtype,
164
+ transformer,
165
+ discrete_flow_shift,
166
+ sample_steps,
167
+ width,
168
+ height,
169
+ frame_count,
170
+ generator,
171
+ do_classifier_free_guidance,
172
+ guidance_scale,
173
+ cfg_scale,
174
+ image_path=None,
175
+ control_video_path=None,
176
+ ):
177
+ """architecture dependent inference"""
178
+ model: WanModel = transformer
179
+ device = accelerator.device
180
+ if cfg_scale is None:
181
+ cfg_scale = 5.0
182
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
183
+
184
+ # Calculate latent video length based on VAE version
185
+ latent_video_length = (frame_count - 1) // self.config["vae_stride"][0] + 1
186
+
187
+ # Get embeddings
188
+ context = sample_parameter["t5_embeds"].to(device=device)
189
+ if do_classifier_free_guidance:
190
+ context_null = sample_parameter["negative_t5_embeds"].to(device=device)
191
+ else:
192
+ context_null = None
193
+
194
+ num_channels_latents = 16 # model.in_dim
195
+ vae_scale_factor = self.config["vae_stride"][1]
196
+
197
+ # Initialize latents
198
+ lat_h = height // vae_scale_factor
199
+ lat_w = width // vae_scale_factor
200
+ shape_or_frame = (1, num_channels_latents, 1, lat_h, lat_w)
201
+ latents = []
202
+ for _ in range(latent_video_length):
203
+ latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=torch.float32))
204
+ latents = torch.cat(latents, dim=2)
205
+
206
+ image_latents = None
207
+ if self.i2v_training or self.control_training:
208
+ # Move VAE to the appropriate device for sampling: consider to cache image latents in CPU in advance
209
+ vae.to(device)
210
+ vae.eval()
211
+
212
+ if self.i2v_training:
213
+ image = Image.open(image_path)
214
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
215
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).float() # C, 1, H, W
216
+ image = image / 127.5 - 1 # -1 to 1
217
+
218
+ # Create mask for the required number of frames
219
+ msk = torch.ones(1, frame_count, lat_h, lat_w, device=device)
220
+ msk[:, 1:] = 0
221
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
222
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
223
+ msk = msk.transpose(1, 2) # B, C, T, H, W
224
+
225
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
226
+ # Zero padding for the required number of frames only
227
+ padding_frames = frame_count - 1 # The first frame is the input image
228
+ image = torch.concat([image, torch.zeros(3, padding_frames, height, width)], dim=1).to(device=device)
229
+ y = vae.encode([image])[0]
230
+
231
+ y = y[:, :latent_video_length] # may be not needed
232
+ y = y.unsqueeze(0) # add batch dim
233
+ image_latents = torch.concat([msk, y], dim=1)
234
+
235
+ if self.control_training:
236
+ # Control video
237
+ video = load_video(control_video_path, 0, frame_count, bucket_reso=(width, height)) # list of frames
238
+ video = np.stack(video, axis=0) # F, H, W, C
239
+ video = torch.from_numpy(video).permute(3, 0, 1, 2).float() # C, F, H, W
240
+ video = video / 127.5 - 1 # -1 to 1
241
+ video = video.to(device=device)
242
+
243
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
244
+ control_latents = vae.encode([video])[0]
245
+ control_latents = control_latents[:, :latent_video_length]
246
+ control_latents = control_latents.unsqueeze(0) # add batch dim
247
+
248
+ # We supports Wan2.1-Fun-Control only
249
+ if image_latents is not None:
250
+ image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
251
+ image_latents[:, :, 1:] = 0 # remove except the first frame
252
+ else:
253
+ image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
254
+
255
+ image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
256
+
257
+ vae.to("cpu")
258
+ clean_memory_on_device(device)
259
+
260
+ # use the default value for num_train_timesteps (1000)
261
+ scheduler = FlowUniPCMultistepScheduler(shift=1, use_dynamic_shifting=False)
262
+ scheduler.set_timesteps(sample_steps, device=device, shift=discrete_flow_shift)
263
+ timesteps = scheduler.timesteps
264
+
265
+ # Generate noise for the required number of frames only
266
+ noise = torch.randn(16, latent_video_length, lat_h, lat_w, dtype=torch.float32, generator=generator, device=device).to(
267
+ "cpu"
268
+ )
269
+
270
+ # prepare the model input
271
+ max_seq_len = latent_video_length * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
272
+ arg_c = {"context": [context], "seq_len": max_seq_len}
273
+ arg_null = {"context": [context_null], "seq_len": max_seq_len}
274
+
275
+ if self.i2v_training:
276
+ arg_c["clip_fea"] = sample_parameter["clip_embeds"].to(device=device, dtype=dit_dtype)
277
+ arg_null["clip_fea"] = arg_c["clip_fea"]
278
+ if self.i2v_training or self.control_training:
279
+ arg_c["y"] = image_latents
280
+ arg_null["y"] = image_latents
281
+
282
+ # Wrap the inner loop with tqdm to track progress over timesteps
283
+ prompt_idx = sample_parameter.get("enum", 0)
284
+ latent = noise
285
+ with torch.no_grad():
286
+ for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
287
+ latent_model_input = [latent.to(device=device)]
288
+ timestep = t.unsqueeze(0)
289
+
290
+ with accelerator.autocast():
291
+ noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
292
+ if do_classifier_free_guidance:
293
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
294
+ else:
295
+ noise_pred_uncond = None
296
+
297
+ if do_classifier_free_guidance:
298
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
299
+ else:
300
+ noise_pred = noise_pred_cond
301
+
302
+ temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator)[0]
303
+ latent = temp_x0.squeeze(0)
304
+
305
+ # Move VAE to the appropriate device for sampling
306
+ vae.to(device)
307
+ vae.eval()
308
+
309
+ # Decode latents to video
310
+ logger.info(f"Decoding video from latents: {latent.shape}")
311
+ latent = latent.unsqueeze(0) # add batch dim
312
+ latent = latent.to(device=device)
313
+
314
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
315
+ video = vae.decode(latent)[0] # vae returns list
316
+ video = video.unsqueeze(0) # add batch dim
317
+ del latent
318
+
319
+ logger.info(f"Decoding complete")
320
+ video = video.to(torch.float32).cpu()
321
+ video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
322
+
323
+ vae.to("cpu")
324
+ clean_memory_on_device(device)
325
+
326
+ return video
327
+
328
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
329
+ vae_path = args.vae
330
+
331
+ logger.info(f"Loading VAE model from {vae_path}")
332
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
333
+ vae = WanVAE(vae_path=vae_path, device="cpu", dtype=vae_dtype, cache_device=cache_device)
334
+ return vae
335
+
336
+ def load_transformer(
337
+ self,
338
+ accelerator: Accelerator,
339
+ args: argparse.Namespace,
340
+ dit_path: str,
341
+ attn_mode: str,
342
+ split_attn: bool,
343
+ loading_device: str,
344
+ dit_weight_dtype: Optional[torch.dtype],
345
+ ):
346
+ model = load_wan_model(
347
+ self.config, accelerator.device, dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.fp8_scaled
348
+ )
349
+ return model
350
+
351
+ def scale_shift_latents(self, latents):
352
+ return latents
353
+
354
+ def call_dit(
355
+ self,
356
+ args: argparse.Namespace,
357
+ accelerator: Accelerator,
358
+ transformer,
359
+ latents: torch.Tensor,
360
+ batch: dict[str, torch.Tensor],
361
+ noise: torch.Tensor,
362
+ noisy_model_input: torch.Tensor,
363
+ timesteps: torch.Tensor,
364
+ network_dtype: torch.dtype,
365
+ ):
366
+ model: WanModel = transformer
367
+
368
+ # I2V training and Control training
369
+ image_latents = None
370
+ clip_fea = None
371
+ if self.i2v_training:
372
+ image_latents = batch["latents_image"]
373
+ image_latents = image_latents.to(device=accelerator.device, dtype=network_dtype)
374
+ clip_fea = batch["clip"]
375
+ clip_fea = clip_fea.to(device=accelerator.device, dtype=network_dtype)
376
+ if self.control_training:
377
+ control_latents = batch["latents_control"]
378
+ control_latents = control_latents.to(device=accelerator.device, dtype=network_dtype)
379
+ if image_latents is not None:
380
+ image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
381
+ image_latents[:, :, 1:] = 0 # remove except the first frame
382
+ else:
383
+ image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
384
+ image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
385
+ control_latents = None
386
+
387
+ context = [t.to(device=accelerator.device, dtype=network_dtype) for t in batch["t5"]]
388
+
389
+ # ensure the hidden state will require grad
390
+ if args.gradient_checkpointing:
391
+ noisy_model_input.requires_grad_(True)
392
+ for t in context:
393
+ t.requires_grad_(True)
394
+ if image_latents is not None:
395
+ image_latents.requires_grad_(True)
396
+ if clip_fea is not None:
397
+ clip_fea.requires_grad_(True)
398
+
399
+ # call DiT
400
+ lat_f, lat_h, lat_w = latents.shape[2:5]
401
+ seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[0] * self.config.patch_size[1] * self.config.patch_size[2])
402
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
403
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
404
+ with accelerator.autocast():
405
+ model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
406
+ model_pred = torch.stack(model_pred, dim=0) # list to tensor
407
+
408
+ # flow matching loss
409
+ target = noise - latents
410
+
411
+ return model_pred, target
412
+
413
+ # endregion model specific
414
+
415
+
416
+ def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
417
+ """Wan2.1 specific parser setup"""
418
+ parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
419
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
420
+ parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
421
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
422
+ parser.add_argument(
423
+ "--clip",
424
+ type=str,
425
+ default=None,
426
+ help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
427
+ )
428
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
429
+ return parser
430
+
431
+
432
+ if __name__ == "__main__":
433
+ parser = setup_parser_common()
434
+ parser = wan_setup_parser(parser)
435
+
436
+ args = parser.parse_args()
437
+ args = read_config_from_file(args, parser)
438
+
439
+ args.dit_dtype = None # automatically detected
440
+ if args.vae_dtype is None:
441
+ args.vae_dtype = "bfloat16" # make bfloat16 as default for VAE
442
+
443
+ trainer = WanNetworkTrainer()
444
+ trainer.train(args)