svjack committed
Commit 4d42572 (verified) · 1 parent: 71e89ad

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50):
  1. .gitattributes +3 -0
  2. cache_latents.py +61 -3
  3. cache_text_encoder_outputs.py +3 -3
  4. convert_lora.py +2 -0
  5. dataset/config_utils.py +30 -2
  6. dataset/dataset_config.md +231 -80
  7. dataset/image_video_dataset.py +561 -72
  8. docs/advanced_config.md +166 -1
  9. docs/framepack.md +607 -0
  10. docs/framepack_1f.md +359 -0
  11. docs/kisekaeichi_ref.png +3 -0
  12. docs/kisekaeichi_ref_mask.png +0 -0
  13. docs/kisekaeichi_result.png +3 -0
  14. docs/kisekaeichi_start.png +3 -0
  15. docs/kisekaeichi_start_mask.png +0 -0
  16. docs/sampling_during_training.md +18 -10
  17. docs/wan.md +302 -12
  18. fpack_cache_latents.py +524 -0
  19. fpack_cache_text_encoder_outputs.py +110 -0
  20. fpack_generate_video.py +1832 -0
  21. fpack_train_network.py +617 -0
  22. frame_pack/__init__.py +0 -0
  23. frame_pack/bucket_tools.py +30 -0
  24. frame_pack/clip_vision.py +14 -0
  25. frame_pack/framepack_utils.py +273 -0
  26. frame_pack/hunyuan.py +134 -0
  27. frame_pack/hunyuan_video_packed.py +2038 -0
  28. frame_pack/k_diffusion_hunyuan.py +128 -0
  29. frame_pack/uni_pc_fm.py +142 -0
  30. frame_pack/utils.py +617 -0
  31. frame_pack/wrapper.py +51 -0
  32. hunyuan_model/fp8_optimization.py +39 -0
  33. hv_generate_video.py +52 -27
  34. hv_train_network.py +110 -11
  35. merge_lora.py +1 -1
  36. modules/fp8_optimization_utils.py +356 -0
  37. networks/lora.py +0 -1
  38. networks/lora_framepack.py +65 -0
  39. pyproject.toml +4 -2
  40. requirements.txt +4 -4
  41. utils/safetensors_utils.py +31 -1
  42. utils/sai_model_spec.py +10 -2
  43. utils/train_utils.py +1 -0
  44. wan/__init__.py +0 -2
  45. wan/configs/__init__.py +46 -19
  46. wan/configs/shared_config.py +1 -0
  47. wan/configs/wan_i2v_14B.py +11 -8
  48. wan/configs/wan_t2v_14B.py +8 -5
  49. wan/configs/wan_t2v_1_3B.py +8 -5
  50. wan/modules/model.py +135 -5
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_ref.png filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_result.png filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_start.png filter=lfs diff=lfs merge=lfs -text
cache_latents.py CHANGED
@@ -86,10 +86,65 @@ def show_console(
     return ord(k) if k else ord(" ")


+def save_video(image: Union[list[Union[Image.Image, np.ndarray]], Union[Image.Image, np.ndarray]], cache_path: str, fps: int = 24):
+    import av
+
+    directory = os.path.dirname(cache_path)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+
+    if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image):
+        # save image
+        image_path = cache_path.replace(".safetensors", ".jpg")
+        img = image if isinstance(image, Image.Image) else Image.fromarray(image)
+        img.save(image_path)
+        print(f"Saved image: {image_path}")
+    else:
+        imgs = image
+        print(f"Number of images: {len(imgs)}")
+        # save video
+        video_path = cache_path.replace(".safetensors", ".mp4")
+        height, width = imgs[0].shape[0:2]
+
+        # create output container
+        container = av.open(video_path, mode="w")
+
+        # create video stream
+        codec = "libx264"
+        pixel_format = "yuv420p"
+        stream = container.add_stream(codec, rate=fps)
+        stream.width = width
+        stream.height = height
+        stream.pix_fmt = pixel_format
+        stream.bit_rate = 1000000  # 1Mbit/s for preview quality
+
+        for frame_img in imgs:
+            if isinstance(frame_img, Image.Image):
+                frame = av.VideoFrame.from_image(frame_img)
+            else:
+                frame = av.VideoFrame.from_ndarray(frame_img, format="rgb24")
+            packets = stream.encode(frame)
+            for packet in packets:
+                container.mux(packet)
+
+        for packet in stream.encode():
+            container.mux(packet)
+
+        container.close()
+
+        print(f"Saved video: {video_path}")
+
+
 def show_datasets(
-    datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
+    datasets: list[BaseDataset],
+    debug_mode: str,
+    console_width: int,
+    console_back: str,
+    console_num_images: Optional[int],
+    fps: int = 24,
 ):
-    print(f"d: next dataset, q: quit")
+    if debug_mode != "video":
+        print(f"d: next dataset, q: quit")

     num_workers = max(1, os.cpu_count() - 1)
     for i, dataset in enumerate(datasets):
@@ -110,6 +165,9 @@ def show_datasets(
                 num_images_to_show -= 1
                 if num_images_to_show == 0:
                     k = ord("d")  # next dataset
+            elif debug_mode == "video":
+                save_video(item_info.content, item_info.latent_cache_path, fps)
+                k = None  # save next video

             if k == ord("q"):
                 return
@@ -246,7 +304,7 @@ def setup_parser_common() -> argparse.ArgumentParser:
     parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
     parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
     parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
-    parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
+    parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console", "video"], help="debug mode")
    parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
    parser.add_argument(
        "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
cache_text_encoder_outputs.py CHANGED
@@ -100,14 +100,14 @@ def process_text_encoder_batches(


 def post_process_cache_files(
-    datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set]
+    datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set], keep_cache: bool
 ):
     for i, dataset in enumerate(datasets):
         all_cache_files = all_cache_files_for_dataset[i]
         all_cache_paths = all_cache_paths_for_dataset[i]
         for cache_file in all_cache_files:
             if cache_file not in all_cache_paths:
-                if args.keep_cache:
+                if keep_cache:
                     logger.info(f"Keep cache file not in the dataset: {cache_file}")
                 else:
                     os.remove(cache_file)
@@ -181,7 +181,7 @@ def main(args):
     del text_encoder_2

     # remove cache files not in dataset
-    post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)
+    post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)


 def setup_parser_common():
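The change above only threads the `keep_cache` flag through as a parameter instead of reading a global `args`. As a minimal sketch of the cleanup rule it implements (illustrative only; every name other than `keep_cache` is hypothetical):

```python
from pathlib import Path

def prune_stale_caches(existing_caches: set[str], referenced_caches: set[str], keep_cache: bool) -> None:
    # A cache file that exists on disk but is no longer referenced by the dataset
    # is either kept (--keep_cache) or deleted, mirroring post_process_cache_files above.
    for cache_file in existing_caches - referenced_caches:
        if keep_cache:
            print(f"Keep cache file not in the dataset: {cache_file}")
        else:
            Path(cache_file).unlink()
```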
convert_lora.py CHANGED
@@ -65,6 +65,8 @@ def convert_to_diffusers(prefix, weights_sd):
             # Wan2.1 lora name to module name: ugly but works
             module_name = module_name.replace("cross.attn", "cross_attn")  # fix cross attn
             module_name = module_name.replace("self.attn", "self_attn")  # fix self attn
+            module_name = module_name.replace("k.img", "k_img")  # fix k img
+            module_name = module_name.replace("v.img", "v_img")  # fix v img
         else:
             # HunyuanVideo lora name to module name: ugly but works
             module_name = module_name.replace("double.blocks.", "double_blocks.")  # fix double blocks
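For context, the two new `replace` calls extend the Wan2.1 fix-ups that map LoRA key fragments back to module names. A small sketch of the mapping (the sample key below is hypothetical, chosen only to show the substitutions):

```python
def fix_wan_module_name(module_name: str) -> str:
    # Same substitutions as convert_to_diffusers() above: dots that belong to the
    # module name itself are restored to underscores.
    for src, dst in (("cross.attn", "cross_attn"), ("self.attn", "self_attn"),
                     ("k.img", "k_img"), ("v.img", "v_img")):
        module_name = module_name.replace(src, dst)
    return module_name

print(fix_wan_module_name("blocks.0.cross.attn.k.img"))  # blocks.0.cross_attn.k_img
```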
dataset/config_utils.py CHANGED
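The main behavioral change in the hunks below is how a `[[datasets]]` entry is classified: the presence of `video_directory`/`video_jsonl_file` (or, symmetrically, the image keys) now decides the dataset type, instead of keying off `target_frames`. A minimal sketch of the new rule (not the library's code):

```python
def is_image_dataset(dataset_config: dict) -> bool:
    # New rule from the diff below: the image keys decide the type, so a video
    # dataset no longer needs target_frames just to be recognized as one.
    return "image_directory" in dataset_config or "image_jsonl_file" in dataset_config

assert is_image_dataset({"image_directory": "/path/to/image_dir", "caption_extension": ".txt"})
assert not is_image_dataset({"video_directory": "/path/to/video_dir", "frame_extraction": "full"})
```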
@@ -41,16 +41,29 @@ class BaseDatasetParams:
 class ImageDatasetParams(BaseDatasetParams):
     image_directory: Optional[str] = None
     image_jsonl_file: Optional[str] = None
+    control_directory: Optional[str] = None
+
+    # FramePack dependent parameters
+    fp_latent_window_size: Optional[int] = 9
+    fp_1f_clean_indices: Optional[Sequence[int]] = None
+    fp_1f_target_index: Optional[int] = None
+    fp_1f_no_post: Optional[bool] = False


 @dataclass
 class VideoDatasetParams(BaseDatasetParams):
     video_directory: Optional[str] = None
     video_jsonl_file: Optional[str] = None
+    control_directory: Optional[str] = None
     target_frames: Sequence[int] = (1,)
     frame_extraction: Optional[str] = "head"
     frame_stride: Optional[int] = 1
     frame_sample: Optional[int] = 1
+    max_frames: Optional[int] = 129
+    source_fps: Optional[float] = None
+
+    # FramePack dependent parameters
+    fp_latent_window_size: Optional[int] = 9


 @dataclass
@@ -99,15 +112,23 @@ class ConfigSanitizer:
         "image_directory": str,
         "image_jsonl_file": str,
         "cache_directory": str,
+        "control_directory": str,
+        "fp_latent_window_size": int,
+        "fp_1f_clean_indices": [int],
+        "fp_1f_target_index": int,
+        "fp_1f_no_post": bool,
     }
     VIDEO_DATASET_DISTINCT_SCHEMA = {
         "video_directory": str,
         "video_jsonl_file": str,
+        "control_directory": str,
         "target_frames": [int],
         "frame_extraction": str,
         "frame_stride": int,
         "frame_sample": int,
+        "max_frames": int,
         "cache_directory": str,
+        "source_fps": float,
     }

     # options handled by argparse but not handled by user config
@@ -126,7 +147,7 @@ class ConfigSanitizer:
         )

         def validate_flex_dataset(dataset_config: dict):
-            if "target_frames" in dataset_config:
+            if "video_directory" in dataset_config or "video_jsonl_file" in dataset_config:
                 return Schema(self.video_dataset_schema)(dataset_config)
             else:
                 return Schema(self.image_dataset_schema)(dataset_config)
@@ -194,7 +215,7 @@ class BlueprintGenerator:

         dataset_blueprints = []
         for dataset_config in sanitized_user_config.get("datasets", []):
-            is_image_dataset = "target_frames" not in dataset_config
+            is_image_dataset = "image_directory" in dataset_config or "image_jsonl_file" in dataset_config
             if is_image_dataset:
                 dataset_params_klass = ImageDatasetParams
             else:
@@ -277,6 +298,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
                 f"""\
                 image_directory: "{dataset.image_directory}"
                 image_jsonl_file: "{dataset.image_jsonl_file}"
+                fp_latent_window_size: {dataset.fp_latent_window_size}
+                fp_1f_clean_indices: {dataset.fp_1f_clean_indices}
+                fp_1f_target_index: {dataset.fp_1f_target_index}
+                fp_1f_no_post: {dataset.fp_1f_no_post}
                 \n"""
             ),
             " ",
@@ -287,10 +312,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
                 f"""\
                 video_directory: "{dataset.video_directory}"
                 video_jsonl_file: "{dataset.video_jsonl_file}"
+                control_directory: "{dataset.control_directory}"
                 target_frames: {dataset.target_frames}
                 frame_extraction: {dataset.frame_extraction}
                 frame_stride: {dataset.frame_stride}
                 frame_sample: {dataset.frame_sample}
+                max_frames: {dataset.max_frames}
+                source_fps: {dataset.source_fps}
                 \n"""
             ),
             " ",
dataset/dataset_config.md CHANGED
@@ -2,16 +2,13 @@

 ## Dataset Configuration

-<details>
-<summary>English</summary>
-
 Please create a TOML file for dataset configuration.

 Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.

 The cache directory must be different for each dataset.
-</details>

+Each video is extracted frame by frame without additional processing and used for training. It is recommended to use videos with a frame rate of 24fps for HunyuanVideo, 16fps for Wan2.1 and 30fps for FramePack. You can check the videos that will be trained using `--debug_mode video` when caching latent (see [here](/README.md#latent-caching)).
 <details>
 <summary>日本語</summary>

@@ -20,6 +17,8 @@ The cache directory must be different for each dataset.
 画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。

 キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
+
+動画は追加のプロセスなしでフレームごとに抽出され、学習に用いられます。そのため、HunyuanVideoは24fps、Wan2.1は16fps、FramePackは30fpsのフレームレートの動画を使用することをお勧めします。latentキャッシュ時の`--debug_mode video`を使用すると、学習される動画を確認できます([こちら](/README.ja.md#latentの事前キャッシュ)を参照)。
 </details>

 ### Sample for Image Dataset with Caption Text Files
@@ -44,15 +43,10 @@ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset.
 # other datasets can be added here. each dataset can have different configurations
 ```

-<details>
-<summary>English</summary>
-
 `cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets.

 `num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.

-</details>
-
 <details>
 <summary>日本語</summary>

@@ -108,9 +102,10 @@ metadata jsonl ファイルを使用する場合、caption_extension は必要
 ### Sample for Video Dataset with Caption Text Files

 ```toml
-# resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample,
-# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
-# num_repeats is also available for video dataset, example is not shown here
+# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
+# can be set in either general or datasets sections
+# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
+# must be set in each datasets section

 # general configurations
 [general]
@@ -125,14 +120,38 @@ video_directory = "/path/to/video_dir"
 cache_directory = "/path/to/cache_directory" # recommended to set cache directory
 target_frames = [1, 25, 45]
 frame_extraction = "head"
+source_fps = 30.0 # optional, source fps for videos in the directory, decimal number
+
+[[datasets]]
+video_directory = "/path/to/video_dir2"
+cache_directory = "/path/to/cache_directory2" # recommended to set cache directory
+frame_extraction = "full"
+max_frames = 45

 # other datasets can be added here. each dataset can have different configurations
 ```

+__In HunyuanVideo and Wan2.1, the number of `target_frames` must be "N\*4+1" (N=0,1,2,...).__ Otherwise, it will be truncated to the nearest "N*4+1".
+
+In FramePack, it is recommended to set `frame_extraction` to `full` and `max_frames` to a sufficiently large value, as it can handle longer videos. However, if the video is too long, an Out of Memory error may occur during VAE encoding. The videos in FramePack are trimmed to "N * latent_window_size * 4 + 1" frames (for example, 37, 73, 109... if `latent_window_size` is 9).
+
+If the `source_fps` is specified, the videos in the directory are considered to be at this frame rate, and some frames will be skipped to match the model's frame rate (24 for HunyuanVideo and 16 for Wan2.1). __The value must be a decimal number, for example, `30.0` instead of `30`.__ The skipping is done automatically and does not consider the content of the images. Please check if the converted data is correct using `--debug_mode video`.
+
+If `source_fps` is not specified (default), all frames of the video will be used regardless of the video's frame rate.
+
 <details>
 <summary>日本語</summary>

-resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
+共通パラメータ(resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)は、generalまたはdatasetsのいずれかに設定できます。
+動画固有のパラメータ(target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)は、各datasetsセクションに設定する必要があります。
+
+__HunyuanVideoおよびWan2.1では、target_framesの数値は「N\*4+1」である必要があります。__ これ以外の値の場合は、最も近いN\*4+1の値に切り捨てられます。
+
+FramePackでも同様ですが、FramePackでは動画が長くても学習可能なため、 `frame_extraction`に`full` を指定し、`max_frames`を十分に大きな値に設定することをお勧めします。ただし、あまりにも長すぎるとVAEのencodeでOut of Memoryエラーが発生する可能性があります。FramePackの動画は、「N * latent_window_size * 4 + 1」フレームにトリミングされます(latent_window_sizeが9の場合、37、73、109……)。
+
+`source_fps`を指定した場合、ディレクトリ内の動画をこのフレームレートとみなして、モデルのフレームレートにあうようにいくつかのフレームをスキップします(HunyuanVideoは24、Wan2.1は16)。__小数点を含む数値で指定してください。__ 例:`30`ではなく`30.0`。スキップは機械的に行われ、画像の内容は考慮しません。変換後のデータが正しいか、`--debug_mode video`で確認してください。
+
+`source_fps`を指定しない場合、動画のフレームは(動画自体のフレームレートに関係なく)すべて使用されます。

 他の注意事項は画像データセットと同様です。
 </details>
@@ -140,8 +159,11 @@ resolution, caption_extension, target_frames, frame_extraction, frame_stride, fr
 ### Sample for Video Dataset with Metadata JSONL File

 ```toml
-# resolution, target_frames, frame_extraction, frame_stride, frame_sample,
-# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
+# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
+# can be set in either general or datasets sections
+# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
+# must be set in each datasets section
+
 # caption_extension is not required for metadata jsonl file
 # cache_directory is required for each dataset with metadata jsonl file

@@ -157,7 +179,7 @@ video_jsonl_file = "/path/to/metadata.jsonl"
 target_frames = [1, 25, 45]
 frame_extraction = "head"
 cache_directory = "/path/to/cache_directory_head"
-
+source_fps = 30.0 # optional, source fps for videos in the jsonl file
 # same metadata jsonl file can be used for multiple datasets
 [[datasets]]
 video_jsonl_file = "/path/to/metadata.jsonl"
@@ -175,28 +197,30 @@ JSONL file format for metadata:
 {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
 ```

+`video_path` can be a directory containing multiple images.
+
 <details>
 <summary>日本語</summary>
-
-resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
-
 metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。

+`video_path`は、複数の画像を含むディレクトリのパスでも構いません。
+
 他の注意事項は今までのデータセットと同様です。
 </details>

 ### frame_extraction Options

-<details>
-<summary>English</summary>
-
 - `head`: Extract the first N frames from the video.
 - `chunk`: Extract frames by splitting the video into chunks of N frames.
 - `slide`: Extract frames from the video with a stride of `frame_stride`.
 - `uniform`: Extract `frame_sample` samples uniformly from the video.
+- `full`: Extract all frames from the video.
+
+In the case of `full`, the entire video is used, but it is trimmed to "N*4+1" frames. It is also trimmed to the `max_frames` if it exceeds that value. To avoid Out of Memory errors, please set `max_frames`.
+
+The frame extraction methods other than `full` are recommended when the video contains repeated actions. `full` is recommended when each video represents a single complete motion.

 For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
-</details>

 <details>
 <summary>日本語</summary>
@@ -205,6 +229,11 @@ For example, consider a video with 40 frames. The following diagrams illustrate
 - `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
 - `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
 - `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
+- `full`: 動画から全てのフレームを抽出します。
+
+`full`の場合、各動画の全体を用いますが、「N*4+1」のフレーム数にトリミングされます。また`max_frames`を超える場合もその値にトリミングされます。Out of Memoryエラーを避けるために、`max_frames`を設定してください。
+
+`full`以外の抽出方法は、動画が特定の動作を繰り返している場合にお勧めします。`full`はそれぞれの動画がひとつの完結したモーションの場合にお勧めします。

 例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
 </details>
@@ -251,100 +280,209 @@ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
 oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
 ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
 oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
+
+Three Original Videos, 20, 25, 35 frames: x = frame, o = no frame
+
+full, max_frames = 31 -> extract all frames (trimmed to the maximum length):
+video1: xxxxxxxxxxxxxxxxx (trimmed to 17 frames)
+video2: xxxxxxxxxxxxxxxxxxxxxxxxx (25 frames)
+video3: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx (trimmed to 31 frames)
 ```

-## Specifications
+### Sample for Image Dataset with Control Images

-```toml
-# general configurations
-[general]
-resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
-caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
-batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
-num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
-enable_bucket = true # optional, default is false. Enable bucketing for datasets
-bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
+The dataset with control images. This is used for training the one frame training for FramePack.

-### Image Dataset
+The dataset configuration with caption text files is similar to the image dataset, but with an additional `control_directory` parameter.

-# sample image dataset with caption text files
-[[datasets]]
-image_directory = "/path/to/image_dir"
-caption_extension = ".txt" # required for caption text files, if general caption extension is not set
-resolution = [960, 544] # required if general resolution is not set
-batch_size = 4 # optional, overwrite the default batch size
-num_repeats = 1 # optional, overwrite the default num_repeats
-enable_bucket = false # optional, overwrite the default bucketing setting
-bucket_no_upscale = true # optional, overwrite the default bucketing setting
-cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
+The control images are used from the `control_directory` with the same filename (or different extension) as the image, for example, `image_dir/image1.jpg` and `control_dir/image1.png`. The images in `image_directory` should be the target images (the images to be generated during inference, the changed images). The `control_directory` should contain the starting images for inference. The captions should be stored in `image_directory`.

-# sample image dataset with metadata **jsonl** file
-[[datasets]]
-image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
-resolution = [960, 544] # required if general resolution is not set
-cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
-# caption_extension is not required for metadata jsonl file
-# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
+If multiple control images are specified, the filenames of the control images should be numbered (excluding the extension). For example, specify `image_dir/image1.jpg` and `control_dir/image1_0.png`, `control_dir/image1_1.png`. You can also specify the numbers with four digits, such as `image1_0000.png`, `image1_0001.png`.

-### Video Dataset
+The metadata JSONL file format is the same as the image dataset, but with an additional `control_path` parameter.

-# sample video dataset with caption text files
+```json
+{"image_path": "/path/to/image1.jpg", "control_path": "/path/to/control1.png", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "control_path": "/path/to/control2.png", "caption": "A caption for image2"}
+```
+
+If multiple control images are specified, the attribute names should be `control_path_0`, `control_path_1`, etc.
+
+```json
+{"image_path": "/path/to/image1.jpg", "control_path_0": "/path/to/control1_0.png", "control_path_1": "/path/to/control1_1.png", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "control_path_0": "/path/to/control2_0.png", "control_path_1": "/path/to/control2_1.png", "caption": "A caption for image2"}
+```
+
+The control images can also have an alpha channel. In this case, the alpha channel of the image is used as a mask for the latent.
+
+<details>
+<summary>日本語</summary>
+
+制御画像を持つデータセットです。現時点ではFramePackの単一フレーム学習に使用します。
+
+キャプションファイルを用いる場合は`control_directory`を追加で指定してください。制御画像は、画像と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある画像が使用されます(例:`image_dir/image1.jpg`と`control_dir/image1.png`)。`image_directory`の画像は学習対象の画像(推論時に生成する画像、変化後の画像)としてください。`control_directory`には推論時の開始画像を格納してください。キャプションは`image_directory`へ格納してください。
+
+複数枚の制御画像が指定可能です。この場合、制御画像のファイル名(拡張子を除く)へ数字を付与してください。例えば、`image_dir/image1.jpg`と`control_dir/image1_0.png`, `control_dir/image1_1.png`のように指定します。`image1_0000.png`, `image1_0001.png`のように数字を4桁で指定することもできます。
+
+メタデータJSONLファイルを使用する場合は、`control_path`を追加してください。複数枚の制御画像を指定する場合は、`control_path_0`, `control_path_1`のように数字を付与してください。
+
+制御画像はアルファチャンネルを持つこともできます。この場合、画像のアルファチャンネルはlatentへのマスクとして使用されます。
+
+</details>
+
+### Sample for Video Dataset with Control Images
+
+The dataset with control videos is used for training ControlNet models.
+
+The dataset configuration with caption text files is similar to the video dataset, but with an additional `control_directory` parameter.
+
+The control video for a video is used from the `control_directory` with the same filename (or different extension) as the video, for example, `video_dir/video1.mp4` and `control_dir/video1.mp4` or `control_dir/video1.mov`. The control video can also be a directory without an extension, for example, `video_dir/video1.mp4` and `control_dir/video1`.
+
+```toml
 [[datasets]]
 video_directory = "/path/to/video_dir"
-caption_extension = ".txt" # required for caption text files, if general caption extension is not set
-resolution = [960, 544] # required if general resolution is not set
+control_directory = "/path/to/control_dir" # required for dataset with control videos
+cache_directory = "/path/to/cache_directory" # recommended to set cache directory
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+```

-target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
+The dataset configuration with metadata JSONL file is same as the video dataset, but metadata JSONL file must include the control video paths. The control video path can be a directory containing multiple images.

-# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+```json
+{"video_path": "/path/to/video1.mp4", "control_path": "/path/to/control1.mp4", "caption": "A caption for video1"}
+{"video_path": "/path/to/video2.mp4", "control_path": "/path/to/control2.mp4", "caption": "A caption for video2"}
+```

-frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
-frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
-frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
-# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
+<details>
+<summary>日本語</summary>

-# sample video dataset with metadata jsonl file
+制御動画を持つデータセットです。ControlNetモデルの学習に使用します。
+
+キャプションを用いる場合のデータセット設定は動画データセットと似ていますが、`control_directory`パラメータが追加されています。上にある例を参照してください。ある動画に対する制御用動画として、動画と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある動画が使用されます(例:`video_dir/video1.mp4`と`control_dir/video1.mp4`または`control_dir/video1.mov`)。また、拡張子なしのディレクトリ内の、複数枚の画像を制御用動画として使用することもできます(例:`video_dir/video1.mp4`と`control_dir/video1`)。
+
+データセット設定でメタデータJSONLファイルを使用する場合は、動画と制御用動画のパスを含める必要があります。制御用動画のパスは、複数枚の画像を含むディレクトリのパスでも構いません。
+
+</details>
+
+## Architecture-specific Settings / アーキテクチャ固有の設定
+
+The dataset configuration is shared across all architectures. However, some architectures may require additional settings or have specific requirements for the dataset.
+
+### FramePack
+
+For FramePack, you can set the latent window size for training. It is recommended to set it to 9 for FramePack training. The default value is 9, so you can usually omit this setting.
+
+```toml
 [[datasets]]
-video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
+fp_latent_window_size = 9
+```

-target_frames = [1, 79]
+<details>
+<summary>日本語</summary>

-cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
-# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
+学習時のlatent window sizeを指定できます。FramePackの学習においては、9を指定することを推奨します。省略時は9が使用されますので、通常は省略して構いません。
+
+</details>
+
+### FramePack One Frame Training
+
+For the default one frame training of FramePack, you need to set the following parameters in the dataset configuration:
+
+```toml
+[[datasets]]
+fp_1f_clean_indices = [0]
+fp_1f_target_index = 9
+fp_1f_no_post = false
 ```

-<!--
-# sample image dataset with lance
+**Advanced Settings:**
+
+**Note that these parameters are still experimental, and the optimal values are not yet known.** The parameters may also change in the future.
+
+`fp_1f_clean_indices` sets the `clean_indices` value passed to the FramePack model. You can specify multiple indices. `fp_1f_target_index` sets the index of the frame to be trained (generated). `fp_1f_no_post` sets whether to add a zero value as `clean_latent_post`, default is `false` (add zero value).
+
+The number of control images should match the number of indices specified in `fp_1f_clean_indices`.
+
+The default values mean that the first image (control image) is at index `0`, and the target image (the changed image) is at index `9`.
+
+For training with 1f-mc, set `fp_1f_clean_indices` to `[0, 1]` and `fp_1f_target_index` to `9` (or another value). This allows you to use multiple control images to train a single generated image. The control images will be two in this case.
+
+```toml
 [[datasets]]
-image_lance_dataset = "/path/to/lance_dataset"
-resolution = [960, 544] # required if general resolution is not set
-# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
--->
+fp_1f_clean_indices = [0, 1]
+fp_1f_target_index = 9
+fp_1f_no_post = false
+```

-The metadata with .json file will be supported in the near future.
+For training with kisekaeichi, set `fp_1f_clean_indices` to `[0, 10]` and `fp_1f_target_index` to `1` (or another value). This allows you to use the starting image (the image just before the generation section) and the image following the generation section (equivalent to `clean_latent_post`) to train the first image of the generated video. The control images will be two in this case. `fp_1f_no_post` should be set to `true`.
+
+```toml
+[[datasets]]
+fp_1f_clean_indices = [0, 10]
+fp_1f_target_index = 1
+fp_1f_no_post = true
+```
+
+With `fp_1f_clean_indices` and `fp_1f_target_index`, you can specify any number of control images and any index of the target image for training.
+
+If you set `fp_1f_no_post` to `false`, the `clean_latent_post_index` will be `1 + fp1_latent_window_size`.

+You can also set the `no_2x` and `no_4x` options for cache scripts to disable the clean latents 2x and 4x.

+The 2x indices are `1 + fp1_latent_window_size + 1` for two indices (usually `11, 12`), and the 4x indices are `1 + fp1_latent_window_size + 1 + 2` for sixteen indices (usually `13, 14, ..., 28`), regardless of `fp_1f_no_post` and `no_2x`, `no_4x` settings.
+
+<details>
+<summary>日本語</summary>

-<!--
+※ **以下のパラメータは研究中で最適値はまだ不明です。** またパラメータ自体も変更される可能性があります。
+
+デフォルトの1フレーム学習を行う場合、`fp_1f_clean_indices`に`[0]`を、`fp_1f_target_index`に`9`(または5から15程度の値)を、`no_post`に`false`を設定してください。(記述例は英語版ドキュメントを参照、以降同じ。)
+
+**より高度な設定:**
+
+`fp_1f_clean_indices`は、FramePackモデルに渡される `clean_indices` の値を設定します。複数指定が可能です。`fp_1f_target_index`は、学習(生成)対象のフレームのインデックスを設定します。`fp_1f_no_post`は、`clean_latent_post` をゼロ値で追加するかどうかを設定します(デフォルトは`false`で、ゼロ値で追加します)。
+
+制御画像の枚数は`fp_1f_clean_indices`に指定したインデックスの数とあわせてください。
+
+デフォルトの1フレーム学習では、開始画像(制御画像)1枚をインデックス`0`、生成対象の画像(変化後の画像)をインデックス`9`に設定しています。
+
+1f-mcの学習を行う場合は、`fp_1f_clean_indices`に `[0, 1]`を、`fp_1f_target_index`に`9`を設定してください。これにより動画の先頭の2枚の制御画像を使用して、後続の1枚の生成画像を学習します。制御画像は2枚になります。
+
+kisekaeichiの学習を行う場合は、`fp_1f_clean_indices`に `[0, 10]`を、`fp_1f_target_index`に`1`(または他の値)を設定してください。これは、開始画像(生成セクションの直前の画像)(`clean_latent_pre`に相当)と、生成セクションに続く1枚の画像(`clean_latent_post`に相当)を使用して、生成動画の先頭の画像(`target_index=1`)を学習します。制御画像は2枚になります。`f1_1f_no_post`は`true`に設定してください。
+
+`fp_1f_clean_indices`と`fp_1f_target_index`を応用することで、任意の枚数の制御画像を、任意のインデックスを指定して学習することが可能です。
+
+`fp_1f_no_post`を`false`に設定すると、`clean_latent_post_index`は `1 + fp1_latent_window_size` になります。
+
+推論時の `no_2x`、`no_4x`に対応する設定は、キャッシュスクリプトの引数で行えます。なお、2xのindexは `1 + fp1_latent_window_size + 1` からの2個(通常は`11, 12`)、4xのindexは `1 + fp1_latent_window_size + 1 + 2` からの16個になります(通常は`13, 14, ..., 28`)です。これらの値は`fp_1f_no_post`や`no_2x`, `no_4x`の設定に関わらず、常に同じです。
+
+</details>
+
+## Specifications

 ```toml
 # general configurations
 [general]
-resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
+resolution = [960, 544] # optional, [W, H], default is [960, 544]. This is the default resolution for all datasets
 caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
 batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
+num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
 enable_bucket = true # optional, default is false. Enable bucketing for datasets
 bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false

+### Image Dataset
+
 # sample image dataset with caption text files
 [[datasets]]
 image_directory = "/path/to/image_dir"
 caption_extension = ".txt" # required for caption text files, if general caption extension is not set
 resolution = [960, 544] # required if general resolution is not set
 batch_size = 4 # optional, overwrite the default batch size
+num_repeats = 1 # optional, overwrite the default num_repeats
 enable_bucket = false # optional, overwrite the default bucketing setting
 bucket_no_upscale = true # optional, overwrite the default bucketing setting
 cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
+control_directory = "/path/to/control_dir" # optional, required for dataset with control images

 # sample image dataset with metadata **jsonl** file
 [[datasets]]
@@ -352,36 +490,49 @@ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and
 resolution = [960, 544] # required if general resolution is not set
 cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
 # caption_extension is not required for metadata jsonl file
-# batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
+# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
+
+### Video Dataset

 # sample video dataset with caption text files
 [[datasets]]
 video_directory = "/path/to/video_dir"
 caption_extension = ".txt" # required for caption text files, if general caption extension is not set
 resolution = [960, 544] # required if general resolution is not set
+
+control_directory = "/path/to/control_dir" # optional, required for dataset with control images
+
+# following configurations must be set in each [[datasets]] section for video datasets
+
 target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
+
+# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+
 frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
 frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
 frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
-# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
+max_frames = 129 # optional, default is 129. Maximum number of frames to extract, available for "full" frame extraction
+# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset

 # sample video dataset with metadata jsonl file
 [[datasets]]
 video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
+
 target_frames = [1, 79]
+
 cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
-# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
+# frame_extraction, frame_stride, frame_sample, max_frames are also available for metadata jsonl file
 ```

+<!--
 # sample image dataset with lance
 [[datasets]]
 image_lance_dataset = "/path/to/lance_dataset"
 resolution = [960, 544] # required if general resolution is not set
 # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
+-->

 The metadata with .json file will be supported in the near future.



-
--->
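To make the frame-count rules documented above concrete, here is a small illustrative sketch (an assumption-laden mirror of the described behavior, not the trainer's actual implementation): with `frame_extraction = "full"` a clip is capped at `max_frames` and trimmed down to the nearest "N*4+1" frames, and `source_fps` causes frames to be skipped mechanically so the effective rate matches the model (24 fps for HunyuanVideo, 16 fps for Wan2.1).

```python
def trim_full_extraction(num_frames: int, max_frames: int = 129) -> int:
    # Cap at max_frames, then trim down to the nearest N*4+1 frames.
    n = min(num_frames, max_frames)
    return (n - 1) // 4 * 4 + 1

def kept_frame_indices(num_frames: int, source_fps: float, model_fps: float) -> list[int]:
    # One plausible realization of the documented skipping: sample the source
    # frames at the model's frame rate without inspecting the image content.
    step = source_fps / model_fps
    indices, t = [], 0.0
    while round(t) < num_frames:
        indices.append(int(round(t)))
        t += step
    return indices

print(trim_full_extraction(20, max_frames=31))  # 17, as in the "full" diagram above
print(len(kept_frame_indices(60, 30.0, 16.0)))  # 32 of 60 frames kept for Wan2.1
```

Either way, the converted clips can be inspected with `--debug_mode video`, as the documentation above recommends.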
 
 
 
dataset/image_video_dataset.py CHANGED
@@ -5,7 +5,7 @@ import math
5
  import os
6
  import random
7
  import time
8
- from typing import Optional, Sequence, Tuple, Union
9
 
10
  import numpy as np
11
  import torch
@@ -76,6 +76,8 @@ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
76
  ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
77
  ARCHITECTURE_WAN = "wan"
78
  ARCHITECTURE_WAN_FULL = "wan"
 
 
79
 
80
 
81
  def glob_images(directory, base="*"):
@@ -109,6 +111,8 @@ def divisible_by(num: int, divisor: int) -> int:
109
  def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
110
  """
111
  Resize the image to the bucket resolution.
 
 
112
  """
113
  is_pil_image = isinstance(image, Image.Image)
114
  if is_pil_image:
@@ -120,23 +124,21 @@ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: t
120
  return np.array(image) if is_pil_image else image
121
 
122
  bucket_width, bucket_height = bucket_reso
123
- if bucket_width == image_width or bucket_height == image_height:
124
- image = np.array(image) if is_pil_image else image
 
 
 
 
 
 
 
 
 
 
125
  else:
126
- # resize the image to the bucket resolution to match the short side
127
- scale_width = bucket_width / image_width
128
- scale_height = bucket_height / image_height
129
- scale = max(scale_width, scale_height)
130
- image_width = int(image_width * scale + 0.5)
131
- image_height = int(image_height * scale + 0.5)
132
-
133
- if scale > 1:
134
- image = Image.fromarray(image) if not is_pil_image else image
135
- image = image.resize((image_width, image_height), Image.LANCZOS)
136
- image = np.array(image)
137
- else:
138
- image = np.array(image) if is_pil_image else image
139
- image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
140
 
141
  # crop the image to the bucket resolution
142
  crop_left = (image_width - bucket_width) // 2
@@ -151,7 +153,7 @@ class ItemInfo:
151
  item_key: str,
152
  caption: str,
153
  original_size: tuple[int, int],
154
- bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
155
  frame_count: Optional[int] = None,
156
  content: Optional[np.ndarray] = None,
157
  latent_cache_path: Optional[str] = None,
@@ -165,11 +167,20 @@ class ItemInfo:
165
  self.latent_cache_path = latent_cache_path
166
  self.text_encoder_output_cache_path: Optional[str] = None
167
 
 
 
 
 
 
 
 
 
 
168
  def __str__(self) -> str:
169
  return (
170
  f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
171
  + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
172
- + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
173
  )
174
 
175
 
@@ -181,7 +192,7 @@ class ItemInfo:
181
 
182
 
183
  def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
184
- """HunyuanVideo architecture only"""
185
  assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
186
 
187
  _, F, H, W = latent.shape
@@ -192,7 +203,11 @@ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
192
 
193
 
194
  def save_latent_cache_wan(
195
- item_info: ItemInfo, latent: torch.Tensor, clip_embed: Optional[torch.Tensor], image_latent: Optional[torch.Tensor]
 
 
 
 
196
  ):
197
  """Wan architecture only"""
198
  assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
@@ -207,9 +222,51 @@ def save_latent_cache_wan(
207
  if image_latent is not None:
208
  sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
209
 
 
 
 
210
  save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
211
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
214
  metadata = {
215
  "architecture": arch_fullname,
@@ -260,6 +317,20 @@ def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor)
260
  save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
261
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
264
  for key, value in sd.items():
265
  # NaN check and show warning, replace NaN with 0
@@ -299,6 +370,7 @@ def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, tor
299
  class BucketSelector:
300
  RESOLUTION_STEPS_HUNYUAN = 16
301
  RESOLUTION_STEPS_WAN = 16
 
302
 
303
  def __init__(
304
  self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
@@ -311,6 +383,8 @@ class BucketSelector:
311
  self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
312
  elif self.architecture == ARCHITECTURE_WAN:
313
  self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
 
 
314
  else:
315
  raise ValueError(f"Invalid architecture: {self.architecture}")
316
 
@@ -358,48 +432,142 @@ def load_video(
358
  end_frame: Optional[int] = None,
359
  bucket_selector: Optional[BucketSelector] = None,
360
  bucket_reso: Optional[tuple[int, int]] = None,
 
 
361
  ) -> list[np.ndarray]:
362
  """
363
  bucket_reso: if given, resize the video to the bucket resolution, (width, height)
364
  """
365
- container = av.open(video_path)
366
- video = []
367
- for i, frame in enumerate(container.decode(video=0)):
368
- if start_frame is not None and i < start_frame:
369
- continue
370
- if end_frame is not None and i >= end_frame:
371
- break
372
- frame = frame.to_image()
373
-
374
- if bucket_selector is not None and bucket_reso is None:
375
- bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
376
-
377
- if bucket_reso is not None:
378
- frame = resize_image_to_bucket(frame, bucket_reso)
 
379
  else:
380
- frame = np.array(frame)
 
381
 
382
- video.append(frame)
383
- container.close()
384
  return video
385
 
386
 
387
  class BucketBatchManager:
388
 
389
- def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
390
  self.batch_size = batch_size
391
  self.buckets = bucketed_item_info
392
  self.bucket_resos = list(self.buckets.keys())
393
  self.bucket_resos.sort()
394
 
395
- self.bucket_batch_indices = []
 
396
  for bucket_reso in self.bucket_resos:
397
  bucket = self.buckets[bucket_reso]
398
  num_batches = math.ceil(len(bucket) / self.batch_size)
399
  for i in range(num_batches):
400
  self.bucket_batch_indices.append((bucket_reso, i))
401
 
402
- self.shuffle()
 
403
 
404
  def show_bucket_info(self):
405
  for bucket_reso in self.bucket_resos:
@@ -409,8 +577,11 @@ class BucketBatchManager:
409
  logger.info(f"total batches: {len(self)}")
410
 
411
  def shuffle(self):
 
412
  for bucket in self.buckets.values():
413
  random.shuffle(bucket)
 
 
414
  random.shuffle(self.bucket_batch_indices)
415
 
416
  def __len__(self):
@@ -460,7 +631,8 @@ class BucketBatchManager:
460
 
461
  class ContentDatasource:
462
  def __init__(self):
463
- self.caption_only = False
 
464
 
465
  def set_caption_only(self, caption_only: bool):
466
  self.caption_only = caption_only
@@ -498,10 +670,18 @@ class ImageDatasource(ContentDatasource):
498
 
499
 
500
  class ImageDirectoryDatasource(ImageDatasource):
501
- def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
 
502
  super().__init__()
503
  self.image_directory = image_directory
504
  self.caption_extension = caption_extension
 
 
505
  self.current_idx = 0
506
 
507
  # glob images
@@ -509,19 +689,68 @@ class ImageDirectoryDatasource(ImageDatasource):
509
  self.image_paths = glob_images(self.image_directory)
510
  logger.info(f"found {len(self.image_paths)} images")
511

 
512
  def is_indexable(self):
513
  return True
514
 
515
  def __len__(self):
516
  return len(self.image_paths)
517
 
518
- def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
519
  image_path = self.image_paths[idx]
520
  image = Image.open(image_path).convert("RGB")
521
 
522
  _, caption = self.get_caption(idx)
523
 
524
- return image_path, image, caption
 
525
 
526
  def get_caption(self, idx: int) -> tuple[str, str]:
527
  image_path = self.image_paths[idx]
@@ -559,9 +788,10 @@ class ImageDirectoryDatasource(ImageDatasource):
559
 
560
 
561
  class ImageJsonlDatasource(ImageDatasource):
562
- def __init__(self, image_jsonl_file: str):
563
  super().__init__()
564
  self.image_jsonl_file = image_jsonl_file
 
565
  self.current_idx = 0
566
 
567
  # load jsonl
@@ -577,20 +807,55 @@ class ImageJsonlDatasource(ImageDatasource):
577
  self.data.append(data)
578
  logger.info(f"loaded {len(self.data)} images")
579

 
580
  def is_indexable(self):
581
  return True
582
 
583
  def __len__(self):
584
  return len(self.data)
585
 
586
- def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
587
  data = self.data[idx]
588
  image_path = data["image_path"]
589
  image = Image.open(image_path).convert("RGB")
590
 
591
  caption = data["caption"]
592
 
593
- return image_path, image, caption
 
 
 
 
 
 
 
 
 
 
594
 
595
  def get_caption(self, idx: int) -> tuple[str, str]:
596
  data = self.data[idx]
@@ -634,6 +899,9 @@ class VideoDatasource(ContentDatasource):
634
 
635
  self.bucket_selector = None
636

 
637
  def __len__(self):
638
  raise NotImplementedError
639
 
@@ -650,9 +918,27 @@ class VideoDatasource(ContentDatasource):
650
  end_frame = end_frame if end_frame is not None else self.end_frame
651
  bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
652
 
653
- video = load_video(video_path, start_frame, end_frame, bucket_selector)
 
 
654
  return video
655

656
  def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
657
  self.start_frame = start_frame
658
  self.end_frame = end_frame
@@ -660,6 +946,10 @@ class VideoDatasource(ContentDatasource):
660
  def set_bucket_selector(self, bucket_selector: BucketSelector):
661
  self.bucket_selector = bucket_selector
662

 
663
  def __iter__(self):
664
  raise NotImplementedError
665
 
@@ -668,17 +958,58 @@ class VideoDatasource(ContentDatasource):
668
 
669
 
670
  class VideoDirectoryDatasource(VideoDatasource):
671
- def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
672
  super().__init__()
673
  self.video_directory = video_directory
674
  self.caption_extension = caption_extension
 
675
  self.current_idx = 0
676
 
677
- # glob images
678
- logger.info(f"glob images in {self.video_directory}")
679
  self.video_paths = glob_videos(self.video_directory)
680
  logger.info(f"found {len(self.video_paths)} videos")
681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  def is_indexable(self):
683
  return True
684
 
@@ -691,13 +1022,18 @@ class VideoDirectoryDatasource(VideoDatasource):
691
  start_frame: Optional[int] = None,
692
  end_frame: Optional[int] = None,
693
  bucket_selector: Optional[BucketSelector] = None,
694
- ) -> tuple[str, list[Image.Image], str]:
695
  video_path = self.video_paths[idx]
696
  video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
697
 
698
  _, caption = self.get_caption(idx)
699
 
700
- return video_path, video, caption
 
701
 
702
  def get_caption(self, idx: int) -> tuple[str, str]:
703
  video_path = self.video_paths[idx]
@@ -747,6 +1083,16 @@ class VideoJsonlDatasource(VideoDatasource):
747
  self.data.append(data)
748
  logger.info(f"loaded {len(self.data)} videos")
749

 
750
  def is_indexable(self):
751
  return True
752
 
@@ -759,14 +1105,19 @@ class VideoJsonlDatasource(VideoDatasource):
759
  start_frame: Optional[int] = None,
760
  end_frame: Optional[int] = None,
761
  bucket_selector: Optional[BucketSelector] = None,
762
- ) -> tuple[str, list[Image.Image], str]:
763
  data = self.data[idx]
764
  video_path = data["video_path"]
765
  video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
766
 
767
  caption = data["caption"]
768
 
769
- return video_path, video, caption
 
770
 
771
  def get_caption(self, idx: int) -> tuple[str, str]:
772
  data = self.data[idx]
@@ -973,7 +1324,12 @@ class ImageDataset(BaseDataset):
973
  bucket_no_upscale: bool,
974
  image_directory: Optional[str] = None,
975
  image_jsonl_file: Optional[str] = None,
 
976
  cache_directory: Optional[str] = None,
 
977
  debug_dataset: bool = False,
978
  architecture: str = "no_default",
979
  ):
@@ -990,10 +1346,22 @@ class ImageDataset(BaseDataset):
990
  )
991
  self.image_directory = image_directory
992
  self.image_jsonl_file = image_jsonl_file
 
 
 
 
 
 
 
 
 
 
993
  if image_directory is not None:
994
- self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
 
 
995
  elif image_jsonl_file is not None:
996
- self.datasource = ImageJsonlDatasource(image_jsonl_file)
997
  else:
998
  raise ValueError("image_directory or image_jsonl_file must be specified")
999
 
@@ -1002,6 +1370,7 @@ class ImageDataset(BaseDataset):
1002
 
1003
  self.batch_manager = None
1004
  self.num_train_items = 0
 
1005
 
1006
  def get_metadata(self):
1007
  metadata = super().get_metadata()
@@ -1009,6 +1378,9 @@ class ImageDataset(BaseDataset):
1009
  metadata["image_directory"] = os.path.basename(self.image_directory)
1010
  if self.image_jsonl_file is not None:
1011
  metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
 
1012
  return metadata
1013
 
1014
  def get_total_image_count(self):
@@ -1033,12 +1405,27 @@ class ImageDataset(BaseDataset):
1033
  break # submit batch if possible
1034
 
1035
  for future in completed_futures:
1036
- original_size, item_key, image, caption = future.result()
1037
  bucket_height, bucket_width = image.shape[:2]
1038
  bucket_reso = (bucket_width, bucket_height)
1039
 
1040
  item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
1041
  item_info.latent_cache_path = self.get_latent_cache_path(item_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1042
 
1043
  if bucket_reso not in batches:
1044
  batches[bucket_reso] = []
@@ -1061,14 +1448,21 @@ class ImageDataset(BaseDataset):
1061
  for fetch_op in self.datasource:
1062
 
1063
  # fetch and resize image in a separate thread
1064
- def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
1065
- image_key, image, caption = op()
1066
  image: Image.Image
1067
  image_size = image.size
1068
 
1069
  bucket_reso = buckset_selector.get_bucket_resolution(image_size)
1070
- image = resize_image_to_bucket(image, bucket_reso)
1071
- return image_size, image_key, image, caption
 
 
 
 
 
 
 
1072
 
1073
  future = executor.submit(fetch_and_resize, fetch_op)
1074
  futures.append(future)
@@ -1113,6 +1507,15 @@ class ImageDataset(BaseDataset):
1113
  continue
1114
 
1115
  bucket_reso = bucket_selector.get_bucket_resolution(image_size)
 
1116
  item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
1117
  item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1118
 
@@ -1142,6 +1545,10 @@ class ImageDataset(BaseDataset):
1142
 
1143
 
1144
  class VideoDataset(BaseDataset):
 
1145
  def __init__(
1146
  self,
1147
  resolution: Tuple[int, int],
@@ -1154,9 +1561,13 @@ class VideoDataset(BaseDataset):
1154
  frame_stride: Optional[int] = 1,
1155
  frame_sample: Optional[int] = 1,
1156
  target_frames: Optional[list[int]] = None,
 
 
1157
  video_directory: Optional[str] = None,
1158
  video_jsonl_file: Optional[str] = None,
 
1159
  cache_directory: Optional[str] = None,
 
1160
  debug_dataset: bool = False,
1161
  architecture: str = "no_default",
1162
  ):
@@ -1173,13 +1584,42 @@ class VideoDataset(BaseDataset):
1173
  )
1174
  self.video_directory = video_directory
1175
  self.video_jsonl_file = video_jsonl_file
1176
- self.target_frames = target_frames
1177
  self.frame_extraction = frame_extraction
1178
  self.frame_stride = frame_stride
1179
  self.frame_sample = frame_sample
 
1180
 
1181
  if video_directory is not None:
1182
- self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1183
  elif video_jsonl_file is not None:
1184
  self.datasource = VideoJsonlDatasource(video_jsonl_file)
1185
 
@@ -1195,6 +1635,7 @@ class VideoDataset(BaseDataset):
1195
 
1196
  self.batch_manager = None
1197
  self.num_train_items = 0
 
1198
 
1199
  def get_metadata(self):
1200
  metadata = super().get_metadata()
@@ -1202,20 +1643,29 @@ class VideoDataset(BaseDataset):
1202
  metadata["video_directory"] = os.path.basename(self.video_directory)
1203
  if self.video_jsonl_file is not None:
1204
  metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
 
 
1205
  metadata["frame_extraction"] = self.frame_extraction
1206
  metadata["frame_stride"] = self.frame_stride
1207
  metadata["frame_sample"] = self.frame_sample
1208
  metadata["target_frames"] = self.target_frames
 
 
 
1209
  return metadata
1210
 
1211
  def retrieve_latent_cache_batches(self, num_workers: int):
1212
  buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
1213
  self.datasource.set_bucket_selector(buckset_selector)
 
 
 
 
1214
 
1215
  executor = ThreadPoolExecutor(max_workers=num_workers)
1216
 
1217
- # key: (width, height, frame_count), value: [ItemInfo]
1218
- batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1219
  futures = []
1220
 
1221
  def aggregate_future(consume_all: bool = False):
@@ -1229,13 +1679,25 @@ class VideoDataset(BaseDataset):
1229
  break # submit batch if possible
1230
 
1231
  for future in completed_futures:
1232
- original_frame_size, video_key, video, caption = future.result()
1233
 
1234
  frame_count = len(video)
1235
  video = np.stack(video, axis=0)
1236
  height, width = video.shape[1:3]
1237
  bucket_reso = (width, height) # already resized
1238
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
  crop_pos_and_frames = []
1240
  if self.frame_extraction == "head":
1241
  for target_frame in self.target_frames:
@@ -1260,6 +1722,11 @@ class VideoDataset(BaseDataset):
1260
  frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1261
  for i in frame_indices:
1262
  crop_pos_and_frames.append((i, target_frame))
 
 
 
 
 
1263
  else:
1264
  raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1265
 
@@ -1269,10 +1736,21 @@ class VideoDataset(BaseDataset):
1269
  item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1270
  batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1271
 
 
 
 
 
 
 
 
 
 
1272
  item_info = ItemInfo(
1273
  item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1274
  )
1275
  item_info.latent_cache_path = self.get_latent_cache_path(item_info)
 
 
1276
 
1277
  batch = batches.get(batch_key, [])
1278
  batch.append(item_info)
@@ -1293,8 +1771,15 @@ class VideoDataset(BaseDataset):
1293
 
1294
  for operator in self.datasource:
1295
 
1296
- def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1297
- video_key, video, caption = op()
 
1298
  video: list[np.ndarray]
1299
  frame_size = (video[0].shape[1], video[0].shape[0])
1300
 
@@ -1302,7 +1787,11 @@ class VideoDataset(BaseDataset):
1302
  bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
1303
  video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1304
 
1305
- return frame_size, video_key, video, caption
 
 
 
 
1306
 
1307
  future = executor.submit(fetch_and_resize, operator)
1308
  futures.append(future)
@@ -1340,7 +1829,7 @@ class VideoDataset(BaseDataset):
1340
  image_width, image_height = map(int, image_size.split("x"))
1341
  image_size = (image_width, image_height)
1342
 
1343
- frame_pos, frame_count = tokens[-3].split("-")
1344
  frame_pos, frame_count = int(frame_pos), int(frame_count)
1345
 
1346
  item_key = "_".join(tokens[:-3])
 
5
  import os
6
  import random
7
  import time
8
+ from typing import Any, Optional, Sequence, Tuple, Union
9
 
10
  import numpy as np
11
  import torch
 
76
  ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
77
  ARCHITECTURE_WAN = "wan"
78
  ARCHITECTURE_WAN_FULL = "wan"
79
+ ARCHITECTURE_FRAMEPACK = "fp"
80
+ ARCHITECTURE_FRAMEPACK_FULL = "framepack"
81
 
82
 
83
  def glob_images(directory, base="*"):
 
111
  def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
112
  """
113
  Resize the image to the bucket resolution.
114
+
115
+ bucket_reso: **(width, height)**
116
  """
117
  is_pil_image = isinstance(image, Image.Image)
118
  if is_pil_image:
 
124
  return np.array(image) if is_pil_image else image
125
 
126
  bucket_width, bucket_height = bucket_reso
127
+
128
+ # resize the image to the bucket resolution to match the short side
129
+ scale_width = bucket_width / image_width
130
+ scale_height = bucket_height / image_height
131
+ scale = max(scale_width, scale_height)
132
+ image_width = int(image_width * scale + 0.5)
133
+ image_height = int(image_height * scale + 0.5)
134
+
135
+ if scale > 1:
136
+ image = Image.fromarray(image) if not is_pil_image else image
137
+ image = image.resize((image_width, image_height), Image.LANCZOS)
138
+ image = np.array(image)
139
  else:
140
+ image = np.array(image) if is_pil_image else image
141
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
 
 
142
 
143
  # crop the image to the bucket resolution
144
  crop_left = (image_width - bucket_width) // 2
 
153
  item_key: str,
154
  caption: str,
155
  original_size: tuple[int, int],
156
+ bucket_size: Optional[tuple[Any]] = None,
157
  frame_count: Optional[int] = None,
158
  content: Optional[np.ndarray] = None,
159
  latent_cache_path: Optional[str] = None,
 
167
  self.latent_cache_path = latent_cache_path
168
  self.text_encoder_output_cache_path: Optional[str] = None
169
 
170
+ # np.ndarray for video, list[np.ndarray] for image with multiple controls
171
+ self.control_content: Optional[Union[np.ndarray, list[np.ndarray]]] = None
172
+
173
+ # FramePack architecture specific
174
+ self.fp_latent_window_size: Optional[int] = None
175
+ self.fp_1f_clean_indices: Optional[list[int]] = None # indices of clean latents for 1f
176
+ self.fp_1f_target_index: Optional[int] = None # target index for 1f clean latents
177
+ self.fp_1f_no_post: Optional[bool] = None # whether to add zero values as clean latent post
178
+
179
  def __str__(self) -> str:
180
  return (
181
  f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
182
  + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
183
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path}, content={self.content.shape if self.content is not None else None})"
184
  )
185
 
186
 
 
192
 
193
 
194
  def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
195
+ """HunyuanVideo architecture only. HunyuanVideo doesn't support I2V and control latents"""
196
  assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
197
 
198
  _, F, H, W = latent.shape
 
203
 
204
 
205
  def save_latent_cache_wan(
206
+ item_info: ItemInfo,
207
+ latent: torch.Tensor,
208
+ clip_embed: Optional[torch.Tensor],
209
+ image_latent: Optional[torch.Tensor],
210
+ control_latent: Optional[torch.Tensor],
211
  ):
212
  """Wan architecture only"""
213
  assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
 
222
  if image_latent is not None:
223
  sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
224
 
225
+ if control_latent is not None:
226
+ sd[f"latents_control_{F}x{H}x{W}_{dtype_str}"] = control_latent.detach().cpu()
227
+
228
  save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
229
 
230
 
231
+ def save_latent_cache_framepack(
232
+ item_info: ItemInfo,
233
+ latent: torch.Tensor,
234
+ latent_indices: torch.Tensor,
235
+ clean_latents: torch.Tensor,
236
+ clean_latent_indices: torch.Tensor,
237
+ clean_latents_2x: torch.Tensor,
238
+ clean_latent_2x_indices: torch.Tensor,
239
+ clean_latents_4x: torch.Tensor,
240
+ clean_latent_4x_indices: torch.Tensor,
241
+ image_embeddings: torch.Tensor,
242
+ ):
243
+ """FramePack architecture only"""
244
+ assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
245
+
246
+ _, F, H, W = latent.shape
247
+ dtype_str = dtype_to_str(latent.dtype)
248
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu().contiguous()}
249
+
250
+ # `latents_xxx` must have {F, H, W} suffix
251
+ indices_dtype_str = dtype_to_str(latent_indices.dtype)
252
+ sd[f"image_embeddings_{dtype_str}"] = image_embeddings.detach().cpu() # image embeddings dtype is same as latents dtype
253
+ sd[f"latent_indices_{indices_dtype_str}"] = latent_indices.detach().cpu()
254
+ sd[f"clean_latent_indices_{indices_dtype_str}"] = clean_latent_indices.detach().cpu()
255
+ sd[f"latents_clean_{F}x{H}x{W}_{dtype_str}"] = clean_latents.detach().cpu().contiguous()
256
+ if clean_latent_2x_indices is not None:
257
+ sd[f"clean_latent_2x_indices_{indices_dtype_str}"] = clean_latent_2x_indices.detach().cpu()
258
+ if clean_latents_2x is not None:
259
+ sd[f"latents_clean_2x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_2x.detach().cpu().contiguous()
260
+ if clean_latent_4x_indices is not None:
261
+ sd[f"clean_latent_4x_indices_{indices_dtype_str}"] = clean_latent_4x_indices.detach().cpu()
262
+ if clean_latents_4x is not None:
263
+ sd[f"latents_clean_4x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_4x.detach().cpu().contiguous()
264
+
265
+ # for key, value in sd.items():
266
+ # print(f"{key}: {value.shape}")
267
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
268
+
269
+
270
  def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
271
  metadata = {
272
  "architecture": arch_fullname,
 
317
  save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
318
 
319
 
320
+ def save_text_encoder_output_cache_framepack(
321
+ item_info: ItemInfo, llama_vec: torch.Tensor, llama_attention_mask: torch.Tensor, clip_l_pooler: torch.Tensor
322
+ ):
323
+ """FramePack architecture only."""
324
+ sd = {}
325
+ dtype_str = dtype_to_str(llama_vec.dtype)
326
+ sd[f"llama_vec_{dtype_str}"] = llama_vec.detach().cpu()
327
+ sd[f"llama_attention_mask"] = llama_attention_mask.detach().cpu()
328
+ dtype_str = dtype_to_str(clip_l_pooler.dtype)
329
+ sd[f"clip_l_pooler_{dtype_str}"] = clip_l_pooler.detach().cpu()
330
+
331
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
332
+
333
+
334
  def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
335
  for key, value in sd.items():
336
  # NaN check and show warning, replace NaN with 0
 
370
  class BucketSelector:
371
  RESOLUTION_STEPS_HUNYUAN = 16
372
  RESOLUTION_STEPS_WAN = 16
373
+ RESOLUTION_STEPS_FRAMEPACK = 16
374
 
375
  def __init__(
376
  self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
 
383
  self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
384
  elif self.architecture == ARCHITECTURE_WAN:
385
  self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
386
+ elif self.architecture == ARCHITECTURE_FRAMEPACK:
387
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_FRAMEPACK
388
  else:
389
  raise ValueError(f"Invalid architecture: {self.architecture}")
390
 
 
432
  end_frame: Optional[int] = None,
433
  bucket_selector: Optional[BucketSelector] = None,
434
  bucket_reso: Optional[tuple[int, int]] = None,
435
+ source_fps: Optional[float] = None,
436
+ target_fps: Optional[float] = None,
437
  ) -> list[np.ndarray]:
438
  """
439
  bucket_reso: if given, resize the video to the bucket resolution, (width, height)
440
  """
441
+ if source_fps is None or target_fps is None:
442
+ if os.path.isfile(video_path):
443
+ container = av.open(video_path)
444
+ video = []
445
+ for i, frame in enumerate(container.decode(video=0)):
446
+ if start_frame is not None and i < start_frame:
447
+ continue
448
+ if end_frame is not None and i >= end_frame:
449
+ break
450
+ frame = frame.to_image()
451
+
452
+ if bucket_selector is not None and bucket_reso is None:
453
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size) # calc resolution from first frame
454
+
455
+ if bucket_reso is not None:
456
+ frame = resize_image_to_bucket(frame, bucket_reso)
457
+ else:
458
+ frame = np.array(frame)
459
+
460
+ video.append(frame)
461
+ container.close()
462
  else:
463
+ # load images in the directory
464
+ image_files = glob_images(video_path)
465
+ image_files.sort()
466
+ video = []
467
+ for i in range(len(image_files)):
468
+ if start_frame is not None and i < start_frame:
469
+ continue
470
+ if end_frame is not None and i >= end_frame:
471
+ break
472
+
473
+ image_file = image_files[i]
474
+ image = Image.open(image_file).convert("RGB")
475
+
476
+ if bucket_selector is not None and bucket_reso is None:
477
+ bucket_reso = bucket_selector.get_bucket_resolution(image.size) # calc resolution from first frame
478
+ image = np.array(image)
479
+ if bucket_reso is not None:
480
+ image = resize_image_to_bucket(image, bucket_reso)
481
+
482
+ video.append(image)
483
+ else:
484
+ # drop frames to match the target fps TODO commonize this code with the above if this works
485
+ frame_index_delta = target_fps / source_fps # example: 16 / 30 = 0.5333
486
+ if os.path.isfile(video_path):
487
+ container = av.open(video_path)
488
+ video = []
489
+ frame_index_with_fraction = 0.0
490
+ previous_frame_index = -1
491
+ for i, frame in enumerate(container.decode(video=0)):
492
+ target_frame_index = int(frame_index_with_fraction)
493
+ frame_index_with_fraction += frame_index_delta
494
+
495
+ if target_frame_index == previous_frame_index: # drop this frame
496
+ continue
497
+
498
+ # accept this frame
499
+ previous_frame_index = target_frame_index
500
+
501
+ if start_frame is not None and target_frame_index < start_frame:
502
+ continue
503
+ if end_frame is not None and target_frame_index >= end_frame:
504
+ break
505
+ frame = frame.to_image()
506
+
507
+ if bucket_selector is not None and bucket_reso is None:
508
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size) # calc resolution from first frame
509
+
510
+ if bucket_reso is not None:
511
+ frame = resize_image_to_bucket(frame, bucket_reso)
512
+ else:
513
+ frame = np.array(frame)
514
+
515
+ video.append(frame)
516
+ container.close()
517
+ else:
518
+ # load images in the directory
519
+ image_files = glob_images(video_path)
520
+ image_files.sort()
521
+ video = []
522
+ frame_index_with_fraction = 0.0
523
+ previous_frame_index = -1
524
+ for i in range(len(image_files)):
525
+ target_frame_index = int(frame_index_with_fraction)
526
+ frame_index_with_fraction += frame_index_delta
527
+
528
+ if target_frame_index == previous_frame_index: # drop this frame
529
+ continue
530
+
531
+ # accept this frame
532
+ previous_frame_index = target_frame_index
533
+
534
+ if start_frame is not None and target_frame_index < start_frame:
535
+ continue
536
+ if end_frame is not None and target_frame_index >= end_frame:
537
+ break
538
+
539
+ image_file = image_files[i]
540
+ image = Image.open(image_file).convert("RGB")
541
+
542
+ if bucket_selector is not None and bucket_reso is None:
543
+ bucket_reso = bucket_selector.get_bucket_resolution(image.size) # calc resolution from first frame
544
+ image = np.array(image)
545
+ if bucket_reso is not None:
546
+ image = resize_image_to_bucket(image, bucket_reso)
547
+
548
+ video.append(image)
549
 
 
 
550
  return video
551
 
552
 
553
  class BucketBatchManager:
554
 
555
+ def __init__(self, bucketed_item_info: dict[tuple[Any], list[ItemInfo]], batch_size: int):
556
  self.batch_size = batch_size
557
  self.buckets = bucketed_item_info
558
  self.bucket_resos = list(self.buckets.keys())
559
  self.bucket_resos.sort()
560
 
561
+ # indices for enumerating batches. each batch is reso + batch_idx. reso is (width, height) or (width, height, frames)
562
+ self.bucket_batch_indices: list[tuple[tuple[Any], int]] = []
563
  for bucket_reso in self.bucket_resos:
564
  bucket = self.buckets[bucket_reso]
565
  num_batches = math.ceil(len(bucket) / self.batch_size)
566
  for i in range(num_batches):
567
  self.bucket_batch_indices.append((bucket_reso, i))
568
 
569
+ # do not shuffle here, to avoid multiple datasets having different orders
570
+ # self.shuffle()
571
 
572
  def show_bucket_info(self):
573
  for bucket_reso in self.bucket_resos:
 
577
  logger.info(f"total batches: {len(self)}")
578
 
579
  def shuffle(self):
580
+ # shuffle each bucket
581
  for bucket in self.buckets.values():
582
  random.shuffle(bucket)
583
+
584
+ # shuffle the order of batches
585
  random.shuffle(self.bucket_batch_indices)
586
 
587
  def __len__(self):
 
631
 
632
  class ContentDatasource:
633
  def __init__(self):
634
+ self.caption_only = False # set to True to only fetch caption for Text Encoder caching
635
+ self.has_control = False
636
 
637
  def set_caption_only(self, caption_only: bool):
638
  self.caption_only = caption_only
 
670
 
671
 
672
  class ImageDirectoryDatasource(ImageDatasource):
673
+ def __init__(
674
+ self,
675
+ image_directory: str,
676
+ caption_extension: Optional[str] = None,
677
+ control_directory: Optional[str] = None,
678
+ control_count_per_image: int = 1,
679
+ ):
680
  super().__init__()
681
  self.image_directory = image_directory
682
  self.caption_extension = caption_extension
683
+ self.control_directory = control_directory
684
+ self.control_count_per_image = control_count_per_image
685
  self.current_idx = 0
686
 
687
  # glob images
 
689
  self.image_paths = glob_images(self.image_directory)
690
  logger.info(f"found {len(self.image_paths)} images")
691
 
692
+ # glob control images if specified
693
+ if self.control_directory is not None:
694
+ logger.info(f"glob control images in {self.control_directory}")
695
+ self.has_control = True
696
+ self.control_paths = {}
697
+ for image_path in self.image_paths:
698
+ image_basename = os.path.basename(image_path)
699
+ image_basename_no_ext = os.path.splitext(image_basename)[0]
700
+ potential_paths = glob.glob(os.path.join(self.control_directory, os.path.splitext(image_basename)[0] + "*.*"))
701
+ if potential_paths:
702
+ # sort by the digits (`_0000`) suffix, prefer the one without the suffix
703
+ def sort_key(path):
704
+ basename = os.path.basename(path)
705
+ basename_no_ext = os.path.splitext(basename)[0]
706
+ if image_basename_no_ext == basename_no_ext: # prefer the one without suffix
707
+ return 0
708
+ digits_suffix = basename_no_ext.rsplit("_", 1)[-1]
709
+ if not digits_suffix.isdigit():
710
+ raise ValueError(f"Invalid digits suffix in {basename_no_ext}")
711
+ return int(digits_suffix) + 1
712
+
713
+ potential_paths.sort(key=sort_key)
714
+ if len(potential_paths) < control_count_per_image:
715
+ logger.error(
716
+ f"Not enough control images for {image_path}: found {len(potential_paths)}, expected {control_count_per_image}"
717
+ )
718
+ raise ValueError(
719
+ f"Not enough control images for {image_path}: found {len(potential_paths)}, expected {control_count_per_image}"
720
+ )
721
+
722
+ # take the first `control_count_per_image` paths
723
+ self.control_paths[image_path] = potential_paths[:control_count_per_image]
724
+ logger.info(f"found {len(self.control_paths)} matching control images")
725
+
726
+ missing_controls = len(self.image_paths) - len(self.control_paths)
727
+ if missing_controls > 0:
728
+ missing_control_paths = set(self.image_paths) - set(self.control_paths.keys())
729
+ logger.error(f"Could not find matching control images for {missing_controls} images: {missing_control_paths}")
730
+ raise ValueError(f"Could not find matching control images for {missing_controls} images")
731
+
732
  def is_indexable(self):
733
  return True
734
 
735
  def __len__(self):
736
  return len(self.image_paths)
737
 
738
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
739
  image_path = self.image_paths[idx]
740
  image = Image.open(image_path).convert("RGB")
741
 
742
  _, caption = self.get_caption(idx)
743
 
744
+ controls = None
745
+ if self.has_control:
746
+ controls = []
747
+ for control_path in self.control_paths[image_path]:
748
+ control = Image.open(control_path)
749
+ if control.mode != "RGB" and control.mode != "RGBA":
750
+ control = control.convert("RGB")
751
+ controls.append(control)
752
+
753
+ return image_path, image, caption, controls
754
 
755
  def get_caption(self, idx: int) -> tuple[str, str]:
756
  image_path = self.image_paths[idx]
 
788
 
789
 
790
  class ImageJsonlDatasource(ImageDatasource):
791
+ def __init__(self, image_jsonl_file: str, control_count_per_image: int = 1):
792
  super().__init__()
793
  self.image_jsonl_file = image_jsonl_file
794
+ self.control_count_per_image = control_count_per_image
795
  self.current_idx = 0
796
 
797
  # load jsonl
 
807
  self.data.append(data)
808
  logger.info(f"loaded {len(self.data)} images")
809
 
810
+ # Normalize control paths
811
+ for item in self.data:
812
+ if "control_path" in item:
813
+ item["control_path_0"] = item.pop("control_path")
814
+
815
+ # Ensure control paths are named consistently, from control_path_0000 to control_path_0, control_path_1, etc.
816
+ control_path_keys = [key for key in item.keys() if key.startswith("control_path_")]
817
+ control_path_keys.sort(key=lambda x: int(x.split("_")[-1]))
818
+ for i, key in enumerate(control_path_keys):
819
+ if key != f"control_path_{i}":
820
+ item[f"control_path_{i}"] = item.pop(key)
821
+
822
+ # Check if there are control paths in the JSONL
823
+ self.has_control = any("control_path_0" in item for item in self.data)
824
+ if self.has_control:
825
+ missing_control_images = [
826
+ item["image_path"]
827
+ for item in self.data
828
+ if sum(f"control_path_{i}" not in item for i in range(self.control_count_per_image)) > 0
829
+ ]
830
+ if missing_control_images:
831
+ logger.error(f"Some images do not have control paths in JSONL data: {missing_control_images}")
832
+ raise ValueError(f"Some images do not have control paths in JSONL data: {missing_control_images}")
833
+ logger.info(f"found {len(self.data)} images with {self.control_count_per_image} control images per image in JSONL data")
834
+
835
  def is_indexable(self):
836
  return True
837
 
838
  def __len__(self):
839
  return len(self.data)
840
 
841
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[list[Image.Image]]]:
842
  data = self.data[idx]
843
  image_path = data["image_path"]
844
  image = Image.open(image_path).convert("RGB")
845
 
846
  caption = data["caption"]
847
 
848
+ controls = None
849
+ if self.has_control:
850
+ controls = []
851
+ for i in range(self.control_count_per_image):
852
+ control_path = data[f"control_path_{i}"]
853
+ control = Image.open(control_path)
854
+ if control.mode != "RGB" and control.mode != "RGBA":
855
+ control = control.convert("RGB")
856
+ controls.append(control)
857
+
858
+ return image_path, image, caption, controls
859
 
860
  def get_caption(self, idx: int) -> tuple[str, str]:
861
  data = self.data[idx]
 
899
 
900
  self.bucket_selector = None
901
 
902
+ self.source_fps = None
903
+ self.target_fps = None
904
+
905
  def __len__(self):
906
  raise NotImplementedError
907
 
 
918
  end_frame = end_frame if end_frame is not None else self.end_frame
919
  bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
920
 
921
+ video = load_video(
922
+ video_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
923
+ )
924
  return video
925
 
926
+ def get_control_data_from_path(
927
+ self,
928
+ control_path: str,
929
+ start_frame: Optional[int] = None,
930
+ end_frame: Optional[int] = None,
931
+ bucket_selector: Optional[BucketSelector] = None,
932
+ ) -> list[Image.Image]:
933
+ start_frame = start_frame if start_frame is not None else self.start_frame
934
+ end_frame = end_frame if end_frame is not None else self.end_frame
935
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
936
+
937
+ control = load_video(
938
+ control_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
939
+ )
940
+ return control
941
+
942
  def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
943
  self.start_frame = start_frame
944
  self.end_frame = end_frame
 
946
  def set_bucket_selector(self, bucket_selector: BucketSelector):
947
  self.bucket_selector = bucket_selector
948
 
949
+ def set_source_and_target_fps(self, source_fps: Optional[float], target_fps: Optional[float]):
950
+ self.source_fps = source_fps
951
+ self.target_fps = target_fps
952
+
953
  def __iter__(self):
954
  raise NotImplementedError
955
 
 
958
 
959
 
960
  class VideoDirectoryDatasource(VideoDatasource):
961
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
962
  super().__init__()
963
  self.video_directory = video_directory
964
  self.caption_extension = caption_extension
965
+ self.control_directory = control_directory # newly added: control image directory
966
  self.current_idx = 0
967
 
968
+ # glob videos
969
+ logger.info(f"glob videos in {self.video_directory}")
970
  self.video_paths = glob_videos(self.video_directory)
971
  logger.info(f"found {len(self.video_paths)} videos")
972
 
973
+ # glob control images if specified
974
+ if self.control_directory is not None:
975
+ logger.info(f"glob control videos in {self.control_directory}")
976
+ self.has_control = True
977
+ self.control_paths = {}
978
+ for video_path in self.video_paths:
979
+ video_basename = os.path.basename(video_path)
980
+ # construct control path from video path
981
+ # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mp4"
982
+ control_path = os.path.join(self.control_directory, video_basename)
983
+ if os.path.exists(control_path):
984
+ self.control_paths[video_path] = control_path
985
+ else:
986
+ # use the same base name for control path
987
+ base_name = os.path.splitext(video_basename)[0]
988
+
989
+ # directory with images. for example: video_path = "vid/video.mp4" -> control_path = "control/video"
990
+ potential_path = os.path.join(self.control_directory, base_name) # no extension
991
+ if os.path.isdir(potential_path):
992
+ self.control_paths[video_path] = potential_path
993
+ else:
994
+ # another extension for control path
995
+ # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mov"
996
+ for ext in VIDEO_EXTENSIONS:
997
+ potential_path = os.path.join(self.control_directory, base_name + ext)
998
+ if os.path.exists(potential_path):
999
+ self.control_paths[video_path] = potential_path
1000
+ break
1001
+
1002
+ logger.info(f"found {len(self.control_paths)} matching control videos/images")
1003
+ # check if all videos have matching control paths, if not, raise an error
1004
+ missing_controls = len(self.video_paths) - len(self.control_paths)
1005
+ if missing_controls > 0:
1006
+ # logger.warning(f"Could not find matching control videos/images for {missing_controls} videos")
1007
+ missing_controls_videos = [video_path for video_path in self.video_paths if video_path not in self.control_paths]
1008
+ logger.error(
1009
+ f"Could not find matching control videos/images for {missing_controls} videos: {missing_controls_videos}"
1010
+ )
1011
+ raise ValueError(f"Could not find matching control videos/images for {missing_controls} videos")
1012
+
1013
  def is_indexable(self):
1014
  return True
1015
 
 
1022
  start_frame: Optional[int] = None,
1023
  end_frame: Optional[int] = None,
1024
  bucket_selector: Optional[BucketSelector] = None,
1025
+ ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
1026
  video_path = self.video_paths[idx]
1027
  video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
1028
 
1029
  _, caption = self.get_caption(idx)
1030
 
1031
+ control = None
1032
+ if self.control_directory is not None and video_path in self.control_paths:
1033
+ control_path = self.control_paths[video_path]
1034
+ control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)
1035
+
1036
+ return video_path, video, caption, control
1037
 
1038
  def get_caption(self, idx: int) -> tuple[str, str]:
1039
  video_path = self.video_paths[idx]
 
1083
  self.data.append(data)
1084
  logger.info(f"loaded {len(self.data)} videos")
1085
 
1086
+ # Check if there are control paths in the JSONL
1087
+ self.has_control = any("control_path" in item for item in self.data)
1088
+ if self.has_control:
1089
+ control_count = sum(1 for item in self.data if "control_path" in item)
1090
+ if control_count < len(self.data):
1091
+ missing_control_videos = [item["video_path"] for item in self.data if "control_path" not in item]
1092
+ logger.error(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
1093
+ raise ValueError(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
1094
+ logger.info(f"found {control_count} control videos/images in JSONL data")
1095
+
1096
  def is_indexable(self):
1097
  return True
1098
 
 
1105
  start_frame: Optional[int] = None,
1106
  end_frame: Optional[int] = None,
1107
  bucket_selector: Optional[BucketSelector] = None,
1108
+ ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
1109
  data = self.data[idx]
1110
  video_path = data["video_path"]
1111
  video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
1112
 
1113
  caption = data["caption"]
1114
 
1115
+ control = None
1116
+ if "control_path" in data and data["control_path"]:
1117
+ control_path = data["control_path"]
1118
+ control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)
1119
+
1120
+ return video_path, video, caption, control
1121
 
1122
  def get_caption(self, idx: int) -> tuple[str, str]:
1123
  data = self.data[idx]
 
1324
  bucket_no_upscale: bool,
1325
  image_directory: Optional[str] = None,
1326
  image_jsonl_file: Optional[str] = None,
1327
+ control_directory: Optional[str] = None,
1328
  cache_directory: Optional[str] = None,
1329
+ fp_latent_window_size: Optional[int] = 9,
1330
+ fp_1f_clean_indices: Optional[list[int]] = None,
1331
+ fp_1f_target_index: Optional[int] = None,
1332
+ fp_1f_no_post: Optional[bool] = False,
1333
  debug_dataset: bool = False,
1334
  architecture: str = "no_default",
1335
  ):
 
1346
  )
1347
  self.image_directory = image_directory
1348
  self.image_jsonl_file = image_jsonl_file
1349
+ self.control_directory = control_directory
1350
+ self.fp_latent_window_size = fp_latent_window_size
1351
+ self.fp_1f_clean_indices = fp_1f_clean_indices
1352
+ self.fp_1f_target_index = fp_1f_target_index
1353
+ self.fp_1f_no_post = fp_1f_no_post
1354
+
1355
+ control_count_per_image = 1
1356
+ if fp_1f_clean_indices is not None:
1357
+ control_count_per_image = len(fp_1f_clean_indices)
1358
+
1359
  if image_directory is not None:
1360
+ self.datasource = ImageDirectoryDatasource(
1361
+ image_directory, caption_extension, control_directory, control_count_per_image
1362
+ )
1363
  elif image_jsonl_file is not None:
1364
+ self.datasource = ImageJsonlDatasource(image_jsonl_file, control_count_per_image)
1365
  else:
1366
  raise ValueError("image_directory or image_jsonl_file must be specified")
1367
 
 
1370
 
1371
  self.batch_manager = None
1372
  self.num_train_items = 0
1373
+ self.has_control = self.datasource.has_control
1374
 
1375
  def get_metadata(self):
1376
  metadata = super().get_metadata()
 
1378
  metadata["image_directory"] = os.path.basename(self.image_directory)
1379
  if self.image_jsonl_file is not None:
1380
  metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
1381
+ if self.control_directory is not None:
1382
+ metadata["control_directory"] = os.path.basename(self.control_directory)
1383
+ metadata["has_control"] = self.has_control
1384
  return metadata
1385
 
1386
  def get_total_image_count(self):
 
1405
  break # submit batch if possible
1406
 
1407
  for future in completed_futures:
1408
+ original_size, item_key, image, caption, controls = future.result()
1409
  bucket_height, bucket_width = image.shape[:2]
1410
  bucket_reso = (bucket_width, bucket_height)
1411
 
1412
  item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
1413
  item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1414
+ item_info.fp_latent_window_size = self.fp_latent_window_size
1415
+ item_info.fp_1f_clean_indices = self.fp_1f_clean_indices
1416
+ item_info.fp_1f_target_index = self.fp_1f_target_index
1417
+ item_info.fp_1f_no_post = self.fp_1f_no_post
1418
+
1419
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
1420
+ # we need to split the bucket with latent window size and optional 1f clean indices, zero post
1421
+ bucket_reso = list(bucket_reso) + [self.fp_latent_window_size]
1422
+ if self.fp_1f_clean_indices is not None:
1423
+ bucket_reso.append(len(self.fp_1f_clean_indices))
1424
+ bucket_reso.append(self.fp_1f_no_post)
1425
+ bucket_reso = tuple(bucket_reso)
1426
+
1427
+ if controls is not None:
1428
+ item_info.control_content = controls
1429
 
1430
  if bucket_reso not in batches:
1431
  batches[bucket_reso] = []
 
1448
  for fetch_op in self.datasource:
1449
 
1450
  # fetch and resize image in a separate thread
1451
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str, Optional[Image.Image]]:
1452
+ image_key, image, caption, controls = op()
1453
  image: Image.Image
1454
  image_size = image.size
1455
 
1456
  bucket_reso = buckset_selector.get_bucket_resolution(image_size)
1457
+ image = resize_image_to_bucket(image, bucket_reso) # returns np.ndarray
1458
+ resized_controls = None
1459
+ if controls is not None:
1460
+ resized_controls = []
1461
+ for control in controls:
1462
+ resized_control = resize_image_to_bucket(control, bucket_reso) # returns np.ndarray
1463
+ resized_controls.append(resized_control)
1464
+
1465
+ return image_size, image_key, image, caption, resized_controls
1466
 
1467
  future = executor.submit(fetch_and_resize, fetch_op)
1468
  futures.append(future)
 
1507
  continue
1508
 
1509
  bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1510
+
1511
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
1512
+ # we need to split the bucket with latent window size and optional 1f clean indices, zero post
1513
+ bucket_reso = list(bucket_reso) + [self.fp_latent_window_size]
1514
+ if self.fp_1f_clean_indices is not None:
1515
+ bucket_reso.append(len(self.fp_1f_clean_indices))
1516
+ bucket_reso.append(self.fp_1f_no_post)
1517
+ bucket_reso = tuple(bucket_reso)
1518
+
1519
  item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
1520
  item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1521
 
 
1545
 
1546
 
1547
  class VideoDataset(BaseDataset):
1548
+ TARGET_FPS_HUNYUAN = 24.0
1549
+ TARGET_FPS_WAN = 16.0
1550
+ TARGET_FPS_FRAMEPACK = 30.0
1551
+
1552
  def __init__(
1553
  self,
1554
  resolution: Tuple[int, int],
 
1561
  frame_stride: Optional[int] = 1,
1562
  frame_sample: Optional[int] = 1,
1563
  target_frames: Optional[list[int]] = None,
1564
+ max_frames: Optional[int] = None,
1565
+ source_fps: Optional[float] = None,
1566
  video_directory: Optional[str] = None,
1567
  video_jsonl_file: Optional[str] = None,
1568
+ control_directory: Optional[str] = None,
1569
  cache_directory: Optional[str] = None,
1570
+ fp_latent_window_size: Optional[int] = 9,
1571
  debug_dataset: bool = False,
1572
  architecture: str = "no_default",
1573
  ):
 
1584
  )
1585
  self.video_directory = video_directory
1586
  self.video_jsonl_file = video_jsonl_file
1587
+ self.control_directory = control_directory
1588
  self.frame_extraction = frame_extraction
1589
  self.frame_stride = frame_stride
1590
  self.frame_sample = frame_sample
1591
+ self.max_frames = max_frames
1592
+ self.source_fps = source_fps
1593
+ self.fp_latent_window_size = fp_latent_window_size
1594
+
1595
+ if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
1596
+ self.target_fps = VideoDataset.TARGET_FPS_HUNYUAN
1597
+ elif self.architecture == ARCHITECTURE_WAN:
1598
+ self.target_fps = VideoDataset.TARGET_FPS_WAN
1599
+ elif self.architecture == ARCHITECTURE_FRAMEPACK:
1600
+ self.target_fps = VideoDataset.TARGET_FPS_FRAMEPACK
1601
+ else:
1602
+ raise ValueError(f"Unsupported architecture: {self.architecture}")
1603
+
1604
+ if target_frames is not None:
1605
+ target_frames = list(set(target_frames))
1606
+ target_frames.sort()
1607
+
1608
+ # round each value to N*4+1
1609
+ rounded_target_frames = [(f - 1) // 4 * 4 + 1 for f in target_frames]
1610
+ rounded_target_frames = list(set(rounded_target_frames))
1611
+ rounded_target_frames.sort()
1612
+
1613
+ # if value is changed, warn
1614
+ if target_frames != rounded_target_frames:
1615
+ logger.warning(f"target_frames are rounded to {rounded_target_frames}")
1616
+
1617
+ target_frames = tuple(rounded_target_frames)
1618
+
1619
+ self.target_frames = target_frames
1620
 
1621
  if video_directory is not None:
1622
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension, control_directory)
1623
  elif video_jsonl_file is not None:
1624
  self.datasource = VideoJsonlDatasource(video_jsonl_file)
1625
 
 
1635
 
1636
  self.batch_manager = None
1637
  self.num_train_items = 0
1638
+ self.has_control = self.datasource.has_control
1639
 
1640
  def get_metadata(self):
1641
  metadata = super().get_metadata()
 
1643
  metadata["video_directory"] = os.path.basename(self.video_directory)
1644
  if self.video_jsonl_file is not None:
1645
  metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1646
+ if self.control_directory is not None:
1647
+ metadata["control_directory"] = os.path.basename(self.control_directory)
1648
  metadata["frame_extraction"] = self.frame_extraction
1649
  metadata["frame_stride"] = self.frame_stride
1650
  metadata["frame_sample"] = self.frame_sample
1651
  metadata["target_frames"] = self.target_frames
1652
+ metadata["max_frames"] = self.max_frames
1653
+ metadata["source_fps"] = self.source_fps
1654
+ metadata["has_control"] = self.has_control
1655
  return metadata
1656
 
1657
  def retrieve_latent_cache_batches(self, num_workers: int):
1658
  buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
1659
  self.datasource.set_bucket_selector(buckset_selector)
1660
+ if self.source_fps is not None:
1661
+ self.datasource.set_source_and_target_fps(self.source_fps, self.target_fps)
1662
+ else:
1663
+ self.datasource.set_source_and_target_fps(None, None) # no conversion
1664
 
1665
  executor = ThreadPoolExecutor(max_workers=num_workers)
1666
 
1667
+ # key: (width, height, frame_count) and optional latent_window_size, value: [ItemInfo]
1668
+ batches: dict[tuple[Any], list[ItemInfo]] = {}
1669
  futures = []
1670
 
1671
  def aggregate_future(consume_all: bool = False):
 
1679
  break # submit batch if possible
1680
 
1681
  for future in completed_futures:
1682
+ original_frame_size, video_key, video, caption, control = future.result()
1683
 
1684
  frame_count = len(video)
1685
  video = np.stack(video, axis=0)
1686
  height, width = video.shape[1:3]
1687
  bucket_reso = (width, height) # already resized
1688
 
1689
+ # process control images if available
1690
+ control_video = None
1691
+ if control is not None:
1692
+ # set frame count to the same as video
1693
+ if len(control) > frame_count:
1694
+ control = control[:frame_count]
1695
+ elif len(control) < frame_count:
1696
+ # if control is shorter than video, repeat the last frame
1697
+ last_frame = control[-1]
1698
+ control.extend([last_frame] * (frame_count - len(control)))
1699
+ control_video = np.stack(control, axis=0)
1700
+
1701
  crop_pos_and_frames = []
1702
  if self.frame_extraction == "head":
1703
  for target_frame in self.target_frames:
 
1722
  frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1723
  for i in frame_indices:
1724
  crop_pos_and_frames.append((i, target_frame))
1725
+ elif self.frame_extraction == "full":
1726
+ # select all frames
1727
+ target_frame = min(frame_count, self.max_frames)
1728
+ target_frame = (target_frame - 1) // 4 * 4 + 1 # round to N*4+1
1729
+ crop_pos_and_frames.append((0, target_frame))
1730
  else:
1731
  raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1732
 
 
1736
  item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1737
  batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1738
 
1739
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
1740
+ # add latent window size to bucket resolution
1741
+ batch_key = (*batch_key, self.fp_latent_window_size)
1742
+
1743
+ # crop control video if available
1744
+ cropped_control = None
1745
+ if control_video is not None:
1746
+ cropped_control = control_video[crop_pos : crop_pos + target_frame]
1747
+
1748
  item_info = ItemInfo(
1749
  item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1750
  )
1751
  item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1752
+ item_info.control_content = cropped_control # None is allowed
1753
+ item_info.fp_latent_window_size = self.fp_latent_window_size
1754
 
1755
  batch = batches.get(batch_key, [])
1756
  batch.append(item_info)
 
1771
 
1772
  for operator in self.datasource:
1773
 
1774
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
1775
+ result = op()
1776
+
1777
+ if len(result) == 3: # for backward compatibility TODO remove this in the future
1778
+ video_key, video, caption = result
1779
+ control = None
1780
+ else:
1781
+ video_key, video, caption, control = result
1782
+
1783
  video: list[np.ndarray]
1784
  frame_size = (video[0].shape[1], video[0].shape[0])
1785
 
 
1787
  bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
1788
  video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1789
 
1790
+ # resize control if necessary
1791
+ if control is not None:
1792
+ control = [resize_image_to_bucket(frame, bucket_reso) for frame in control]
1793
+
1794
+ return frame_size, video_key, video, caption, control
1795
 
1796
  future = executor.submit(fetch_and_resize, operator)
1797
  futures.append(future)
 
1829
  image_width, image_height = map(int, image_size.split("x"))
1830
  image_size = (image_width, image_height)
1831
 
1832
+ frame_pos, frame_count = tokens[-3].split("-")[:2] # "00000-000", or optional section index "00000-000-00"
1833
  frame_pos, frame_count = int(frame_pos), int(frame_count)
1834
 
1835
  item_key = "_".join(tokens[:-3])
docs/advanced_config.md CHANGED
@@ -2,6 +2,16 @@
2
 
3
  # Advanced configuration / 高度な設定
4

 
5
  ## How to specify `network_args` / `network_args`の指定方法
6
 
7
  The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`.
@@ -148,4 +158,159 @@ Specify the project name with `--log_tracker_name` when using wandb.
148
  `--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
149
 
150
  wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
151
- </details>
 
 
2
 
3
  # Advanced configuration / 高度な設定
4
 
5
+ ## Table of contents / 目次
6
+
7
+ - [How to specify `network_args`](#how-to-specify-network_args--network_argsの指定方法)
8
+ - [LoRA+](#lora)
9
+ - [Select the target modules of LoRA](#select-the-target-modules-of-lora--loraの対象モジュールを選択する)
10
+ - [Save and view logs in TensorBoard format](#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)
11
+ - [Save and view logs in wandb](#save-and-view-logs-in-wandb--wandbでログの保存と参照)
12
+ - [FP8 weight optimization for models](#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化)
13
+ - [PyTorch Dynamo optimization for model training](#pytorch-dynamo-optimization-for-model-training--モデルの学習におけるpytorch-dynamoの最適化)
14
+
15
  ## How to specify `network_args` / `network_args`の指定方法
16
 
17
  The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`.
 
158
  `--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
159
 
160
  wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
161
+ </details>
162
+
163
+ ## FP8 weight optimization for models / モデルの重みのFP8への最適化
164
+
165
+ The `--fp8_scaled` option is available to quantize the weights of the model to FP8 (E4M3) format with appropriate scaling. This reduces the VRAM usage while maintaining precision. Important weights are kept in FP16/BF16/FP32 format.
166
+
167
+ The model weights must be in fp16 or bf16. Weights that have been pre-converted to float8_e4m3 cannot be used.
168
+
169
+ Wan2.1 inference and training are supported.
170
+
171
+ Specify the `--fp8_scaled` option in addition to the `--fp8` option during inference.
172
+
173
+ Specify the `--fp8_scaled` option in addition to the `--fp8_base` option during training.
174
+
175
+ Acknowledgments: This feature is based on the [implementation](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py) of [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). The selection of high-precision modules is based on the [implementation](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py) of [diffusion-pipe](https://github.com/tdrussell/diffusion-pipe). I would like to thank these repositories.
176
+
177
+ <details>
178
+ <summary>日本語</summary>
179
+ 重みを単純にFP8へcastするのではなく、適切なスケーリングでFP8形式に量子化することで、精度を維持しつつVRAM使用量を削減します。また、重要な重みはFP16/BF16/FP32形式で保持します。
180
+
181
+ モデルの重みは、fp16またはbf16が必要です。あらかじめfloat8_e4m3に変換された重みは使用できません。
182
+
183
+ Wan2.1の推論、学習のみ対応しています。
184
+
185
+ 推論時は`--fp8`オプションに加えて `--fp8_scaled`オプションを指定してください。
186
+
187
+ 学習時は`--fp8_base`オプションに加えて `--fp8_scaled`オプションを指定してください。
188
+
189
+ 謝辞:この機能は、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)の[実装](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py)を参考にしました。また、高精度モジュールの選択においては[diffusion-pipe](https://github.com/tdrussell/diffusion-pipe)の[実装](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py)を参考にしました。これらのリポジトリに感謝します。
190
+
191
+ </details>
192
+
193
+ ### Key features and implementation details / 主な特徴と実装の詳細
194
+
195
+ - Implements FP8 (E4M3) weight quantization for Linear layers
196
+ - Reduces VRAM requirements by using 8-bit weights for storage (VRAM usage is slightly higher than with the existing `--fp8` / `--fp8_base` options)
197
+ - Quantizes weights to FP8 format with appropriate scaling instead of simple cast to FP8
198
+ - Maintains computational precision by dequantizing to original precision (FP16/BF16/FP32) during forward pass
199
+ - Preserves important weights in FP16/BF16/FP32 format
200
+
201
+ The implementation:
202
+
203
+ 1. Quantizes weights to FP8 format with appropriate scaling
204
+ 2. Replaces weights by FP8 quantized weights and stores scale factors in model state dict
205
+ 3. Applies monkey patching to Linear layers for transparent dequantization during computation
206
+
207
+ <details>
208
+ <summary>日本語</summary>
209
+
210
+ - Linear層のFP8(E4M3)重み量子化を実装
211
+ - 8ビットの重みを使用することでVRAM使用量を削減(既存の`--fp8` `--fp8_base` オプションに比べて微増)
212
+ - 単純なFP8へのcastではなく、適切な値でスケールして重みをFP8形式に量子化
213
+ - forward時に元の精度(FP16/BF16/FP32)に逆量子化して計算精度を維持
214
+ - 精度が重要な重みはFP16/BF16/FP32のまま保持
215
+
216
+ 実装:
217
+
218
+ 1. 精度を維持できる適切な倍率で重みをFP8形式に量子化
219
+ 2. 重みをFP8量子化重みに置き換え、倍率をモデルのstate dictに保存
220
+ 3. Linear層にmonkey patchingすることでモデルを変更せずに逆量子化
221
+ </details>
222
+
223
+ ## PyTorch Dynamo optimization for model training / モデルの学習におけるPyTorch Dynamoの最適化
224
+
225
+ The PyTorch Dynamo options are now available to optimize the training process. PyTorch Dynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster by using TorchInductor, a deep learning compiler. This integration allows for potential speedups in training while maintaining model accuracy.
226
+
227
+ [PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) added this feature.
228
+
229
+ Specify the `--dynamo_backend` option to enable Dynamo optimization with one of the available backends from the `DynamoBackend` enum.
230
+
231
+ Additional options allow for fine-tuning the Dynamo behavior:
232
+ - `--dynamo_mode`: Controls the optimization strategy
233
+ - `--dynamo_fullgraph`: Enables fullgraph mode for potentially better optimization
234
+ - `--dynamo_dynamic`: Enables dynamic shape handling
235
+
236
+ The `--dynamo_dynamic` option has been reported to have many problems based on the validation in PR #215.
237
+
238
+ ### Available options:
239
+
240
+ ```
241
+ --dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, etc.}
242
+ Specifies the Dynamo backend to use (default is NO, which disables Dynamo)
243
+
244
+ --dynamo_mode {default, reduce-overhead, max-autotune}
245
+ Specifies the optimization mode (default is 'default')
246
+ - 'default': Standard optimization
247
+ - 'reduce-overhead': Focuses on reducing compilation overhead
248
+ - 'max-autotune': Performs extensive autotuning for potentially better performance
249
+
250
+ --dynamo_fullgraph
251
+ Flag to enable fullgraph mode, which attempts to capture and optimize the entire model graph
252
+
253
+ --dynamo_dynamic
254
+ Flag to enable dynamic shape handling for models with variable input shapes
255
+ ```
256
+
257
+ ### Usage example:
258
+
259
+ ```bash
260
+ python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default
261
+ ```
262
+
263
+ For more aggressive optimization:
264
+ ```bash
265
+ python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph
266
+ ```
267
+
268
+ Note: The best combination of options may depend on your specific model and hardware. Experimentation may be necessary to find the optimal configuration.
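+ For reference, these flags map roughly onto the arguments of `torch.compile` (a simplified sketch for orientation only; the training scripts apply the options through the `DynamoBackend` mechanism rather than calling `torch.compile` directly like this):
+
+ ```python
+ import torch
+
+ model = torch.nn.Linear(64, 64)  # placeholder for the model being trained
+
+ # roughly what --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph enables
+ compiled = torch.compile(
+     model,
+     backend="inductor",
+     mode="max-autotune",
+     fullgraph=True,
+     dynamic=False,  # --dynamo_dynamic would set this to True
+ )
+ out = compiled(torch.randn(1, 64))
+ ```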
269
+
270
+ <details>
271
+ <summary>日本語</summary>
272
+ PyTorch Dynamoオプションが学習プロセスを最適化するために追加されました。PyTorch Dynamoは、TorchInductor(ディープラーニングコンパイラ)を使用して、変更を加えることなくPyTorchプログラムを高速化するためのPythonレベルのJITコンパイラです。この統合により、モデルの精度を維持しながら学習の高速化が期待できます。
273
+
274
+ [PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) で追加されました。
275
+
276
+ `--dynamo_backend`オプションを指定して、`DynamoBackend`列挙型から利用可能なバックエンドの一つを選択することで、Dynamo最適化を有効にします。
277
+
278
+ 追加のオプションにより、Dynamoの動作を微調整できます:
279
+ - `--dynamo_mode`:最適化戦略を制御します
280
+ - `--dynamo_fullgraph`:より良い最適化の可能性のためにフルグラフモードを有効にします
281
+ - `--dynamo_dynamic`:動的形状処理を有効にします
282
+
283
+ PR #215での検証によると、`--dynamo_dynamic`には問題が多いことが報告されています。
284
+
285
+ __利用可能なオプション:__
286
+
287
+ ```
288
+ --dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, など}
289
+ 使用するDynamoバックエンドを指定します(デフォルトはNOで、Dynamoを無効にします)
290
+
291
+ --dynamo_mode {default, reduce-overhead, max-autotune}
292
+ 最適化モードを指定します(デフォルトは 'default')
293
+ - 'default':標準的な最適化
294
+ - 'reduce-overhead':コンパイルのオーバーヘッド削減に焦点を当てる
295
+ - 'max-autotune':より良いパフォーマンスのために広範な自動調整を実行
296
+
297
+ --dynamo_fullgraph
298
+ フルグラフモードを有効にするフラグ。モデルグラフ全体をキャプチャして最適化しようとします
299
+
300
+ --dynamo_dynamic
301
+ 可変入力形状を持つモデルのための動的形状処理を有効にするフラグ
302
+ ```
303
+
304
+ __使用例:__
305
+
306
+ ```bash
307
+ python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default
308
+ ```
309
+
310
+ より積極的な最適化の場合:
311
+ ```bash
312
+ python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph
313
+ ```
314
+
315
+ 注意:最適なオプションの組み合わせは、特定のモデルとハードウェアに依存する場���があります。最適な構成を見つけるために実験が必要かもしれません。
316
+ </details>
docs/framepack.md ADDED
@@ -0,0 +1,607 @@
 
 
1
+ # FramePack
2
+
3
+ ## Overview / 概要
4
+
5
+ This document describes the usage of the [FramePack](https://github.com/lllyasviel/FramePack) architecture within the Musubi Tuner framework. FramePack is a novel video generation architecture developed by lllyasviel.
6
+
7
+ Key differences from HunyuanVideo:
8
+ - FramePack only supports Image-to-Video (I2V) generation. Text-to-Video (T2V) is not supported.
9
+ - It utilizes a different DiT model architecture and requires an additional Image Encoder. The VAE is the same as HunyuanVideo's. The Text Encoders appear to be the same as HunyuanVideo's, but we use them following the original FramePack method.
10
+ - Caching and training scripts are specific to FramePack (`fpack_*.py`).
11
+ - Due to its progressive generation nature, VRAM usage can be significantly lower, especially for longer videos, compared to other architectures.
12
+
13
+ There is no detailed official explanation of how to train the model; the training support here is based on the FramePack implementation and paper.
14
+
15
+ This feature is experimental.
16
+
17
+ For one-frame inference and training, see [here](./framepack_1f.md).
18
+
19
+ <details>
20
+ <summary>日本語</summary>
21
+
22
+ このドキュメントは、Musubi Tunerフレームワーク内での[FramePack](https://github.com/lllyasviel/FramePack) アーキテクチャの使用法について説明しています。FramePackは、lllyasviel氏によって開発された新しいビデオ生成アーキテクチャです。
23
+
24
+ HunyuanVideoとの主な違いは次のとおりです。
25
+ - FramePackは、画像からビデオ(I2V)生成のみをサポートしています。テキストからビデオ(T2V)はサポートされていません。
26
+ - 異なるDiTモデルアーキテクチャを使用し、追加の画像エンコーダーが必要です。VAEはHunyuanVideoと同じです。テキストエンコーダーはHunyuanVideoと同じと思われますが、FramePack公式と同じ方法で推論を行っています。
27
+ - キャッシングと学習スクリプトはFramePack専用(`fpack_*.py`)です。
28
+ - セクションずつ生成するため、他のアーキテクチャと比較して、特に長いビデオの場合、VRAM使用量が大幅に少なくなる可能性があります。
29
+
30
+ 学習方法について公式からは詳細な説明はありませんが、FramePackの実装と論文を参考にしています。
31
+
32
+ この機能は実験的なものです。
33
+
34
+ 1フレーム推論、学習については[こちら](./framepack_1f.md)を参照してください。
35
+ </details>
36
+
37
+ ## Download the model / モデルのダウンロード
38
+
39
+ You need to download the DiT, VAE, Text Encoder 1 (LLaMA), Text Encoder 2 (CLIP), and Image Encoder (SigLIP) models specifically for FramePack. Several download options are available for each component.
40
+
41
+ **Note:** The weights are publicly available on the following page: [maybleMyers/framepack_h1111](https://huggingface.co/maybleMyers/framepack_h1111) (except for FramePack-F1). Thank you maybleMyers!
42
+
43
+ ### DiT Model
44
+
45
+ Choose one of the following methods:
46
+
47
+ 1. **From lllyasviel's Hugging Face repo:** Download the three `.safetensors` files (starting with `diffusion_pytorch_model-00001-of-00003.safetensors`) from [lllyasviel/FramePackI2V_HY](https://huggingface.co/lllyasviel/FramePackI2V_HY). Specify the path to the first file (`...-00001-of-00003.safetensors`) as the `--dit` argument. For FramePack-F1, download from [lllyasviel/FramePack_F1_I2V_HY_20250503](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503).
48
+
49
+ 2. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--FramePackI2V_HY/snapshots/<hex-uuid-folder>`. FramePack-F1 is also available in the same way.
50
+
51
+ 3. **From Kijai's Hugging Face repo:** Download the single file `FramePackI2V_HY_bf16.safetensors` from [Kijai/HunyuanVideo_comfy](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors). Specify the path to this file as the `--dit` argument. No FramePack-F1 model is available here currently.
52
+
53
+ ### VAE Model
54
+
55
+ Choose one of the following methods:
56
+
57
+ 1. **Use official HunyuanVideo VAE:** Follow the instructions in the main [README.md](../README.md#model-download).
58
+ 2. **From hunyuanvideo-community Hugging Face repo:** Download `vae/diffusion_pytorch_model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo).
59
+ 3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the VAE might be downloaded locally within the HunyuanVideo community model snapshot. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/<hex-uuid-folder>`.
60
+
61
+ ### Text Encoder 1 (LLaMA) Model
62
+
63
+ Choose one of the following methods:
64
+
65
+ 1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/llava_llama3_fp16.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged).
66
+ 2. **From hunyuanvideo-community Hugging Face repo:** Download the four `.safetensors` files (starting with `text_encoder/model-00001-of-00004.safetensors`) from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo). Specify the path to the first file (`...-00001-of-00004.safetensors`) as the `--text_encoder1` argument.
67
+ 3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/<hex-uuid-folder>`.
68
+
69
+ ### Text Encoder 2 (CLIP) Model
70
+
71
+ Choose one of the following methods:
72
+
73
+ 1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/clip_l.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged).
74
+ 2. **From hunyuanvideo-community Hugging Face repo:** Download `text_encoder_2/model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo).
75
+ 3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/<hex-uuid-folder>`.
76
+
77
+ ### Image Encoder (SigLIP) Model
78
+
79
+ Choose one of the following methods:
80
+
81
+ 1. **From Comfy-Org Hugging Face repo:** Download `sigclip_vision_patch14_384.safetensors` from [Comfy-Org/sigclip_vision_384](https://huggingface.co/Comfy-Org/sigclip_vision_384).
82
+ 2. **From lllyasviel's Hugging Face repo:** Download `image_encoder/model.safetensors` from [lllyasviel/flux_redux_bfl](https://huggingface.co/lllyasviel/flux_redux_bfl).
83
+ 3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/<hex-uuid-folder>`.
84
+
85
+ <details>
86
+ <summary>日本語</summary>
87
+
88
+ ※以下のページに重みが一括で公開されています(FramePack-F1を除く)。maybleMyers 氏に感謝いたします。: https://huggingface.co/maybleMyers/framepack_h1111
89
+
90
+ DiT、VAE、テキストエンコーダー1(LLaMA)、テキストエンコーダー2(CLIP)、および画像エンコーダー(SigLIP)モデルは複数の方法でダウンロードできます。英語の説明を参考にして、ダウンロードしてください。
91
+
92
+ FramePack公式のリポジトリをクローンして実行した場合、モデルはローカルにダウンロードされている可能性があります。スナップショットディレクトリへのパスを指定してください。例:`path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/<hex-uuid-folder>`
93
+
94
+ HunyuanVideoの推論をComfyUIですでに行っている場合、いくつかのモデルはすでにダウンロードされている可能性があります。
95
+ </details>
96
+
97
+ ## Pre-caching / 事前キャッシング
98
+
99
+ The default resolution for FramePack is 640x640. See [the source code](../frame_pack/bucket_tools.py) for the default resolution of each bucket.
100
+
101
+ The dataset for training must be a video dataset. Image datasets are not supported. You can train on videos of any length. Specify `frame_extraction` as `full` and set `max_frames` to a sufficiently large value. However, if the video is too long, you may run out of VRAM during VAE encoding.
102
+
103
+ ### Latent Pre-caching / latentの事前キャッシング
104
+
105
+ Latent pre-caching uses a dedicated script for FramePack. You **must** provide the Image Encoder model.
106
+
107
+ ```bash
108
+ python fpack_cache_latents.py \
109
+ --dataset_config path/to/toml \
110
+ --vae path/to/vae_model.safetensors \
111
+ --image_encoder path/to/image_encoder_model.safetensors \
112
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
113
+ ```
114
+
115
+ Key differences from HunyuanVideo caching:
116
+ - Uses `fpack_cache_latents.py`.
117
+ - Requires the `--image_encoder` argument pointing to the downloaded SigLIP model.
118
+ - The script generates multiple cache files per video, each corresponding to a different section, with the section index appended to the filename (e.g., `..._frame_pos-0000-count_...` becomes `..._frame_pos-0000-0000-count_...`, `..._frame_pos-0000-0001-count_...`, etc.).
119
+ - Image embeddings are calculated using the Image Encoder and stored in the cache files alongside the latents.
120
+
121
+ For VRAM savings during VAE decoding, consider using `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size`. If VRAM is overflowing and using shared memory, it is recommended to set `--vae_chunk_size` to 16 or 8, and `--vae_spatial_tile_sample_min_size` to 64 or 32.
122
+
123
+ Specifying `--f1` is required for FramePack-F1 training, and `--one_frame` for one-frame training. If you later add or remove either of these options, overwrite the existing cache by re-running without `--skip_existing`.
124
+
125
+ The `--one_frame_no_2x` and `--one_frame_no_4x` options are available for one-frame training; they are described in the [one-frame documentation](./framepack_1f.md).
126
+
127
+ **FramePack-F1 support:**
128
+ You can apply the FramePack-F1 sampling method by specifying `--f1` during caching. The training script also requires specifying `--f1` to change the options during sample generation.
129
+
130
+ By default, the sampling method used is Inverted anti-drifting (the same as during inference with the original FramePack model, using the latent and index in reverse order), described in the paper. You can switch to FramePack-F1 sampling (Vanilla sampling, using the temporally ordered latent and index) by specifying `--f1`.
131
+
132
+ <details>
133
+ <summary>日本語</summary>
134
+
135
+ FramePackのデフォルト解像度は640x640です。各バケットのデフォルト解像度については、[ソースコード](../frame_pack/bucket_tools.py)を参照してください。
136
+
137
+ 画像データセットでの学習は行えません。また動画の長さによらず学習可能です。 `frame_extraction` に `full` を指定して、`max_frames` に十分に大きな値を指定してください。ただし、あまりにも長いとVAEのencodeでVRAMが不足する可能性があります。
138
+
139
+ latentの事前キャッシングはFramePack専用のスクリプトを使用します。画像エンコーダーモデルを指定する必要があります。
140
+
141
+ HunyuanVideoのキャッシングとの主な違いは次のとおりです。
142
+ - `fpack_cache_latents.py`を使用します。
143
+ - ダウンロードしたSigLIPモデルを指す`--image_encoder`引数が必要です。
144
+ - スクリプトは、各ビデオに対して複数のキャッシュファイルを生成します。各ファイルは異なるセクションに対応し、セクションインデックスがファイル名に追加されます(例:`..._frame_pos-0000-count_...`は`..._frame_pos-0000-0000-count_...`、`..._frame_pos-0000-0001-count_...`などになります)。
145
+ - 画像埋め込みは画像エンコーダーを使用して計算され、latentとともにキャッシュファイルに保存されます。
146
+
147
+ VAEのdecode時のVRAM節約のために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`を使用することを検討してください。VRAMがあふれて共有メモリを使用している場合には、`--vae_chunk_size`を16、8などに、`--vae_spatial_tile_sample_min_size`を64、32などに変更することをお勧めします。
148
+
149
+ FramePack-F1の学習を行う場合は`--f1`を指定してください。これらのオプションの有無を変更する場合には、`--skip_existing`を指定せずに既存のキャッシュを上書きしてください。
150
+
151
+ **FramePack-F1のサポート:**
152
+ キャッシュ時のオプションに`--f1`を指定することで、FramePack-F1のサンプリング方法を適用できます。学習スクリプトについても`--f1`を指定してサンプル生成時のオプションを変更する必要があります。
153
+
154
+ デフォルトでは、論文のサンプリング方法 Inverted anti-drifting (無印のFramePackの推論時と同じ、逆順の latent と index を使用)を使用します。`--f1`を指定すると FramePack-F1 の Vanilla sampling (時間順の latent と index を使用)に変更できます。
155
+ </details>
156
+
157
+ ### Text Encoder Output Pre-caching / テキストエンコーダー出力の事前キャッシング
158
+
159
+ Text encoder output pre-caching also uses a dedicated script.
160
+
161
+ ```bash
162
+ python fpack_cache_text_encoder_outputs.py \
163
+ --dataset_config path/to/toml \
164
+ --text_encoder1 path/to/text_encoder1 \
165
+ --text_encoder2 path/to/text_encoder2 \
166
+ --batch_size 16
167
+ ```
168
+
169
+ Key differences from HunyuanVideo caching:
170
+ - Uses `fpack_cache_text_encoder_outputs.py`.
171
+ - Requires both `--text_encoder1` (LLaMA) and `--text_encoder2` (CLIP) arguments.
172
+ - Uses `--fp8_llm` option to run the LLaMA Text Encoder 1 in fp8 mode for VRAM savings (similar to `--fp8_t5` in Wan2.1).
173
+ - Saves LLaMA embeddings, attention mask, and CLIP pooler output to the cache file.
174
+
175
+ <details>
176
+ <summary>日本語</summary>
177
+
178
+ テキストエンコーダー出力の事前キャッシングも専用のスクリプトを使用します。
179
+
180
+ HunyuanVideoのキャッシングとの主な違いは次のとおりです。
181
+ - `fpack_cache_text_encoder_outputs.py`を使用します。
182
+ - LLaMAとCLIPの両方の引数が必要です。
183
+ - LLaMAテキストエンコーダー1をfp8モードで実行するための`--fp8_llm`オプションを使用します(Wan2.1の`--fp8_t5`に似ています)。
184
+ - LLaMAの埋め込み、アテンションマスク、CLIPのプーラー出力をキャッシュファイルに保存します。
185
+
186
+ </details>
187
+
188
+
189
+ ## Training / 学習
190
+
191
+ ### Training
192
+
193
+ Training uses a dedicated script `fpack_train_network.py`. Remember FramePack only supports I2V training.
194
+
195
+ ```bash
196
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 fpack_train_network.py \
197
+ --dit path/to/dit_model \
198
+ --vae path/to/vae_model.safetensors \
199
+ --text_encoder1 path/to/text_encoder1 \
200
+ --text_encoder2 path/to/text_encoder2 \
201
+ --image_encoder path/to/image_encoder_model.safetensors \
202
+ --dataset_config path/to/toml \
203
+ --sdpa --mixed_precision bf16 \
204
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
205
+ --timestep_sampling shift --weighting_scheme none --discrete_flow_shift 3.0 \
206
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers \
207
+ --network_module networks.lora_framepack --network_dim 32 \
208
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 \
209
+ --output_dir path/to/output_dir --output_name name-of-lora
210
+ ```
211
+
212
+ If you use the command prompt on Windows (not PowerShell), you may need to write the command on a single line, or use `^` instead of `\` at the end of each line to continue it.
213
+
214
+ The maximum value for `--blocks_to_swap` is 36. The default resolution for FramePack is 640x640, which requires around 17GB of VRAM. If you run out of VRAM, consider lowering the dataset resolution.
215
+
216
+ Key differences from HunyuanVideo training:
217
+ - Uses `fpack_train_network.py`.
218
+ - `--f1` option is available for FramePack-F1 model training. You need to specify the FramePack-F1 model as `--dit`. This option only changes the sample generation during training. The training process itself is the same as the original FramePack model.
219
+ - **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`.
220
+ - **Requires** specifying `--network_module networks.lora_framepack`.
221
+ - Optional `--latent_window_size` argument (default 9, should match caching).
222
+ - Memory saving options like `--fp8` (for DiT) and `--fp8_llm` (for Text Encoder 1) are available. `--fp8_scaled` is recommended when using `--fp8` for DiT.
223
+ - `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are available for the VAE to prevent out-of-memory during sampling (similar to caching).
224
+ - `--gradient_checkpointing` is available for memory savings.
225
+ - If you encounter an error when the batch size is greater than 1 (with `--sdpa` or `--xformers` this always results in an error), specify `--split_attn`.
226
+ <!-- - Use `convert_lora.py` for converting the LoRA weights after training, similar to HunyuanVideo. -->
227
+
228
+ Training settings (learning rate, optimizers, etc.) are experimental. Feedback is welcome.
229
+
230
+ <details>
231
+ <summary>日本語</summary>
232
+
233
+ FramePackの学習は専用のスクリプト`fpack_train_network.py`を使用します。FramePackはI2V学習のみをサポートしています。
234
+
235
+ コマンド記述例は英語版を参考にしてください。WindowsでPowerShellではなくコマンドプロンプトを使用している場合、コマンドを1行で記述するか、各行の末尾に`\`の代わりに`^`を付けてコマンドを続ける必要があります。
236
+
237
+ `--blocks_to_swap`の最大値は36です。FramePackのデフォルト解像度(640x640)では、17GB程度のVRAMが必要です。VRAM容量が不足する場合は、データセットの解像度を下げてください。
238
+
239
+ HunyuanVideoの学習との主な違いは次のとおりです。
240
+ - `fpack_train_network.py`を使用します。
241
+ - FramePack-F1モデルの学習時には`--f1`を指定してください。この場合、`--dit`にFramePack-F1モデルを指定する必要があります。このオプションは学習時のサンプル生成時のみに影響し、学習プロセス自体は元のFramePackモデルと同じです。
242
+ - `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。
243
+ - `--network_module networks.lora_framepack`を指定する必要があります。
244
+ - 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時と一致させる必要があります)。
245
+ - `--fp8`(DiT用)や`--fp8_llm`(テキストエンコーダー1用)などのメモリ節約オプションが利用可能です。`--fp8_scaled`を使用することをお勧めします。
246
+ - サンプル生成時にメモリ不足を防ぐため、VAE用の`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`オプションが利用可能です(キャッシング時と同様)。
247
+ - メモリ節約のために`--gradient_checkpointing`が利用可能です。
248
+ - バッチサイズが1より大きい場合にエラーが出た時には(特に`--sdpa`や`--xformers`を指定すると必ずエラーになります。)、`--split_attn`を指定してください。
249
+
250
+ </details>
251
+
252
+ ## Inference
253
+
254
+ Inference uses a dedicated script `fpack_generate_video.py`.
255
+
256
+ ```bash
257
+ python fpack_generate_video.py \
258
+ --dit path/to/dit_model \
259
+ --vae path/to/vae_model.safetensors \
260
+ --text_encoder1 path/to/text_encoder1 \
261
+ --text_encoder2 path/to/text_encoder2 \
262
+ --image_encoder path/to/image_encoder_model.safetensors \
263
+ --image_path path/to/start_image.jpg \
264
+ --prompt "A cat walks on the grass, realistic style." \
265
+ --video_size 512 768 --video_seconds 5 --fps 30 --infer_steps 25 \
266
+ --attn_mode sdpa --fp8_scaled \
267
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
268
+ --save_path path/to/save/dir --output_type both \
269
+ --seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors
270
+ ```
271
+ <!-- --embedded_cfg_scale 10.0 --guidance_scale 1.0 \ -->
272
+
273
+ Key differences from HunyuanVideo inference:
274
+ - Uses `fpack_generate_video.py`.
275
+ - `--f1` option is available for FramePack-F1 model inference (forward generation). You need to specify the FramePack-F1 model as `--dit`.
276
+ - **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`.
277
+ - **Requires** specifying `--image_path` for the starting frame.
278
+ - **Requires** specifying `--video_seconds` or `--video_sections`. `--video_seconds` specifies the length of the video in seconds, while `--video_sections` specifies the number of sections. If `--video_sections` is specified, `--video_seconds` is ignored.
279
+ - `--video_size` is the size of the generated video; height and width are specified in that order.
280
+ - `--prompt`: Prompt for generation.
281
+ - Optional `--latent_window_size` argument (default 9, should match caching and training).
282
+ - `--fp8_scaled` option is available for DiT to reduce memory usage. Quality may be slightly lower. `--fp8_llm` option is available to reduce memory usage of Text Encoder 1. `--fp8` alone is also an option for DiT but `--fp8_scaled` potentially offers better quality.
283
+ - LoRA loading options (`--lora_weight`, `--lora_multiplier`, `--include_patterns`, `--exclude_patterns`) are available. `--lycoris` is also supported.
284
+ - `--embedded_cfg_scale` (default 10.0) controls the distilled guidance scale.
285
+ - `--guidance_scale` (default 1.0) controls the standard classifier-free guidance scale. **Changing this from 1.0 is generally not recommended for the base FramePack model.**
286
+ - `--guidance_rescale` (default 0.0) is available but typically not needed.
287
+ - `--bulk_decode` option can decode all frames at once, potentially faster but uses more VRAM during decoding. `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are recommended to prevent out-of-memory errors.
288
+ - `--sample_solver` (default `unipc`) is available but only `unipc` is implemented.
289
+ - `--save_merged_model` option is available to save the DiT model after merging LoRA weights. Inference is skipped if this is specified.
290
+ - `--latent_paddings` option overrides the default padding for each section. Specify it as a comma-separated list of integers, e.g., `--latent_paddings 0,0,0,0`. This option is ignored if `--f1` is specified.
291
+ - `--custom_system_prompt` option overrides the default system prompt for the LLaMA Text Encoder 1. Specify it as a string. See [here](../hunyuan_model/text_encoder.py#L152) for the default system prompt.
292
+ - `--rope_scaling_timestep_threshold` option is the RoPE scaling timestep threshold, default is None (disabled). If set, RoPE scaling is applied only when the timestep exceeds the threshold. Start with around 800 and adjust as needed. This option is intended for one-frame inference and may not be suitable for other cases.
293
+ - `--rope_scaling_factor` option is the RoPE scaling factor, default is 0.5, assuming a resolution of 2x. For 1.5x resolution, around 0.7 is recommended.
294
+
295
+ Other options like `--video_size`, `--fps`, `--infer_steps`, `--save_path`, `--output_type`, `--seed`, `--attn_mode`, `--blocks_to_swap`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size` function similarly to HunyuanVideo/Wan2.1 where applicable.
296
+
297
+ `--output_type` supports `latent_images` in addition to the options available in HunyuanVideo/Wan2.1. This option saves the latent and image files in the specified directory.
298
+
299
+ The LoRA weights that can be specified in `--lora_weight` are not limited to the FramePack weights trained in this repository. You can also specify the HunyuanVideo LoRA weights from this repository and the HunyuanVideo LoRA weights from diffusion-pipe (automatic detection).
300
+
301
+ The maximum value for `--blocks_to_swap` is 38.
302
+
303
+ <details>
304
+ <summary>日本語</summary>
305
+
306
+ FramePackの推論は専用のスクリプト`fpack_generate_video.py`を使用します。コマンド記述例は英語版を参考にしてください。
307
+
308
+ HunyuanVideoの推論との主な違いは次のとおりです。
309
+ - `fpack_generate_video.py`を使用します。
310
+ - `--f1`を指定すると、FramePack-F1モデルの推論を行います(順方向で生成)。`--dit`にFramePack-F1モデルを指定する必要があります。
311
+ - `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。
312
+ - `--image_path`を指定する必要があります(開始フレーム)。
313
+ - `--video_seconds` または `--video_sections` を指定する必要があります。`--video_seconds`は秒単位でのビデオの長さを指定し、`--video_sections`はセクション数を指定します。`--video_sections`を指定した場合、`--video_seconds`は無視されます。
314
+ - `--video_size`は生成するビデオのサイズで、高さと幅をその順番で指定します。
315
+ - `--prompt`: 生成用のプロンプトです。
316
+ - 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時、学習時と一致させる必要があります)。
317
+ - DiTのメモリ使用量を削減するために、`--fp8_scaled`オプションを指定可能です。品質はやや低下する可能性があります。またText Encoder 1のメモリ使用量を削減するために、`--fp8_llm`オプションを指定可能です。DiT用に`--fp8`単独のオプションも用意されていますが、`--fp8_scaled`の方が品質が良い可能性があります。
318
+ - LoRAの読み込みオプション(`--lora_weight`、`--lora_multiplier`、`--include_patterns`、`--exclude_patterns`)が利用可能です。LyCORISもサポートされています。
319
+ - `--embedded_cfg_scale`(デフォルト10.0)は、蒸留されたガイダンススケールを制御します。通常は変更しないでください。
320
+ - `--guidance_scale`(デフォルト1.0)は、標準の分類器フリーガイダンススケールを制御します。**FramePackモデルのベースモデルでは、通常1.0から変更しないことをお勧めします。**
321
+ - `--guidance_rescale`(デフォルト0.0)も利用可能ですが、通常は必要ありません。
322
+ - `--bulk_decode`オプションは、すべてのフレームを一度にデコードできるオプションです。高速ですが、デコード中にVRAMを多く使用します。VRAM不足エラーを防ぐために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`オプションを指定することをお勧めします。
323
+ - `--sample_solver`(デフォルト`unipc`)は利用可能ですが、`unipc`のみが実装されています。
324
+ - `--save_merged_model`オプションは、LoRAの重みをマージした後にDiTモデルを保存するためのオプションです。これを指定すると推論はスキップされます。
325
+ - `--latent_paddings`オプションは、各セクションのデフォルトのパディングを上書きします。カンマ区切りの整数リストとして指定します。例:`--latent_paddings 0,0,0,0`。`--f1`を指定した場合は無視されます。
326
+ - `--custom_system_prompt`オプションは、LLaMA Text Encoder 1のデフォルトのシステムプロンプトを上書きします。文字列として指定します。デフォルトのシステムプロンプトは[こちら](../hunyuan_model/text_encoder.py#L152)を参照してください。
327
+ - `--rope_scaling_timestep_threshold`オプションはRoPEスケーリングのタイムステップ閾値で、デフォルトはNone(無効)です。設定すると、タイムステップが閾値以上の場合にのみRoPEスケーリングが適用されます。800程度から初めて調整してください。1フレーム推論時での使用を想定しており、それ以外の場合は想定していません。
328
+ - `--rope_scaling_factor`オプションはRoPEスケーリング係数で、デフォルトは0.5で、解像度が2倍の場合を想定しています。1.5倍なら0.7程度が良いでしょう。
329
+
330
+ `--video_size`、`--fps`、`--infer_steps`、`--save_path`、`--output_type`、`--seed`、`--attn_mode`、`--blocks_to_swap`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`などの他のオプションは、HunyuanVideo/Wan2.1と同様に機能します。
331
+
332
+ `--lora_weight`に指定できるLoRAの重みは、当リポジトリで学習したFramePackの重み以外に、当リポジトリのHunyuanVideoのLoRA、diffusion-pipeのHunyuanVideoのLoRAが指定可能です(自動判定)。
333
+
334
+ `--blocks_to_swap`の最大値は38です。
335
+ </details>
336
+
337
+ ## Batch and Interactive Modes / バッチモードとインタラクティブモード
338
+
339
+ In addition to single video generation, FramePack now supports batch generation from a prompt file and interactive prompt input:
340
+
341
+ ### Batch Mode from File / ファイルからのバッチモード
342
+
343
+ Generate multiple videos from prompts stored in a text file:
344
+
345
+ ```bash
346
+ python fpack_generate_video.py --from_file prompts.txt
347
+ --dit path/to/dit_model --vae path/to/vae_model.safetensors
348
+ --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
349
+ --image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
350
+ ```
351
+
352
+ The prompts file format:
353
+ - One prompt per line
354
+ - Empty lines and lines starting with # are ignored (comments)
355
+ - Each line can include prompt-specific parameters using command-line style format:
356
+
357
+ ```
358
+ A beautiful sunset over mountains --w 832 --h 480 --f 5 --d 42 --s 20 --i path/to/start_image.jpg
359
+ A busy city street at night --w 480 --h 832 --i path/to/another_start.jpg
360
+ ```
361
+
362
+ Supported inline parameters (if omitted, default values from the command line are used):
363
+ - `--w`: Width
364
+ - `--h`: Height
365
+ - `--f`: Video seconds
366
+ - `--d`: Seed
367
+ - `--s`: Inference steps
368
+ - `--g` or `--l`: Guidance scale
369
+ - `--i`: Image path (for start image)
370
+ - `--im`: Image mask path
371
+ - `--n`: Negative prompt
372
+ - `--vs`: Video sections
373
+ - `--ei`: End image path
374
+ - `--ci`: Control image path (explained in one-frame inference documentation)
375
+ - `--cim`: Control image mask path (explained in one-frame inference documentation)
376
+ - `--of`: One frame inference mode options (same as `--one_frame_inference` in the command line), options for one-frame inference
377
+
378
+ In batch mode, models are loaded once and reused for all prompts, significantly improving overall generation time compared to multiple single runs.
379
+
380
+ ### Interactive Mode / インタラクティブモード
381
+
382
+ Interactive command-line interface for entering prompts:
383
+
384
+ ```bash
385
+ python fpack_generate_video.py --interactive
386
+ --dit path/to/dit_model --vae path/to/vae_model.safetensors
387
+ --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
388
+ --image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
389
+ ```
390
+
391
+ In interactive mode:
392
+ - Enter prompts directly at the command line
393
+ - Use the same inline parameter format as batch mode
394
+ - Use Ctrl+D (or Ctrl+Z on Windows) to exit
395
+ - Models remain loaded between generations for efficiency
396
+
397
+ <details>
398
+ <summary>日本語</summary>
399
+
400
+ 単一動画の生成に加えて、FramePackは現在、ファイルからのバッチ生成とインタラクティブなプロンプト入力をサポートしています。
401
+
402
+ #### ファイルからのバッチモード
403
+
404
+ テキストファイルに保存されたプロンプトから複数の動画を生成します:
405
+
406
+ ```bash
407
+ python fpack_generate_video.py --from_file prompts.txt
408
+ --dit path/to/dit_model --vae path/to/vae_model.safetensors
409
+ --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
410
+ --image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
411
+ ```
412
+
413
+ プロンプトファイルの形式(サンプルは英語ドキュメントを参照):
414
+ - 1行に1つのプロンプト
415
+ - 空行や#で始まる行は無視されます(コメント)
416
+ - 各行にはコマンドライン形式でプロンプト固有のパラメータを含めることができます:
417
+
418
+ サポートされているインラインパラメータ(省略した場合、コマンドラインのデフォルト値が使用されます)
419
+ - `--w`: 幅
420
+ - `--h`: 高さ
421
+ - `--f`: 動画の秒数
422
+ - `--d`: シード
423
+ - `--s`: 推論ステップ
424
+ - `--g` または `--l`: ガイダンススケール
425
+ - `--i`: 画像パス(開始画像用)
426
+ - `--im`: 画像マスクパス
427
+ - `--n`: ネガティブプロンプト
428
+ - `--vs`: 動画セクション数
429
+ - `--ei`: 終了画像パス
430
+ - `--ci`: 制御画像パス(1フレーム推論のドキュメントで解説)
431
+ - `--cim`: 制御画像マスクパス(1フレーム推論のドキュメントで解説)
432
+ - `--of`: 1フレーム推論モードオプション(コマンドラインの`--one_frame_inference`と同様、1フレーム推論のオプション)
433
+
434
+ バッチモードでは、モデルは一度だけロードされ、すべてのプロンプトで再利用されるため、複数回の単一実行と比較して全体的な生成時間が大幅に改善されます。
435
+
436
+ #### インタラクティブモード
437
+
438
+ プロンプトを入力するためのインタラクティブなコマンドラインインターフェース:
439
+
440
+ ```bash
441
+ python fpack_generate_video.py --interactive
442
+ --dit path/to/dit_model --vae path/to/vae_model.safetensors
443
+ --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
444
+ --image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
445
+ ```
446
+
447
+ インタラクティブモードでは:
448
+ - コマンドラインで直接プロンプトを入力
449
+ - バッチモードと同じインラインパラメータ形式を使用
450
+ - 終了するには Ctrl+D (Windowsでは Ctrl+Z) を使用
451
+ - 効率のため、モデルは生成間で読み込まれたままになります
452
+ </details>
453
+
454
+ ## Advanced Video Control Features (Experimental) / 高度なビデオ制御機能(実験的)
455
+
456
+ This section describes experimental features added to the `fpack_generate_video.py` script to provide finer control over the generated video content, particularly useful for longer videos or sequences requiring specific transitions or states. These features leverage the Inverted Anti-drifting sampling method inherent to FramePack.
457
+
458
+ ### **1. End Image Guidance (`--end_image_path`)**
459
+
460
+ * **Functionality:** Guides the generation process to make the final frame(s) of the video resemble a specified target image.
461
+ * **Usage:** `--end_image_path <path_to_image_file>`
462
+ * **Mechanism:** The provided image is encoded using the VAE. This latent representation is used as a target or starting point during the generation of the final video section (which is the first step in Inverted Anti-drifting).
463
+ * **Use Cases:** Defining a clear ending for the video, such as a character striking a specific pose or a product appearing in a close-up.
464
+
465
+ This option is ignored if `--f1` is specified. The end image is not used in the FramePack-F1 model.
466
+
467
+ ### **2. Section Start Image Guidance (`--image_path` Extended Format)**
468
+
469
+ * **Functionality:** Guides specific sections within the video to start with a visual state close to a provided image.
470
+ * You can force the start image by setting `--latent_paddings` to `0,0,0,0` (a comma-separated list with one value per section). If `latent_paddings` is set to 1 or more, the specified image will be used as a reference image (the default behavior).
471
+ * **Usage:** `--image_path "SECTION_SPEC:path/to/image.jpg;;;SECTION_SPEC:path/to/another.jpg;;;..."`
472
+ * `SECTION_SPEC`: Defines the target section(s). Rules:
473
+ * `0`: The first section of the video (generated last in Inverted Anti-drifting).
474
+ * `-1`: The last section of the video (generated first).
475
+ * `N` (non-negative integer): The N-th section (0-indexed).
476
+ * `-N` (negative integer): The N-th section from the end.
477
+ * `S-E` (range, e.g., `0-2`): Applies the same image guidance to sections S through E (inclusive).
478
+ * Use `;;;` as a separator between definitions.
479
+ * If no image is specified for a section, generation proceeds based on the prompt and preceding (future time) section context.
480
+ * **Mechanism:** When generating a specific section, if a corresponding start image is provided, its VAE latent representation is strongly referenced as the "initial state" for that section. This guides the beginning of the section towards the specified image while attempting to maintain temporal consistency with the subsequent (already generated) section.
481
+ * **Use Cases:** Defining clear starting points for scene changes, specifying character poses or attire at the beginning of certain sections.
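+ For clarity, the `SECTION_SPEC` syntax can be interpreted as in the following sketch (a hypothetical helper for illustration, not the actual parser used by `fpack_generate_video.py`; here negative indices are kept as-is and would be resolved against the total section count later):
+
+ ```python
+ def parse_section_spec(arg: str) -> dict[int, str]:
+     """Parse e.g. '0:a.png;;;1-2:b.png;;;-1:c.png' into {section_index: value}."""
+     result: dict[int, str] = {}
+     for part in arg.split(";;;"):
+         spec, value = part.split(":", 1)
+         if "-" in spec and not spec.startswith("-"):
+             start, end = map(int, spec.split("-"))  # range form "S-E", inclusive
+             for idx in range(start, end + 1):
+                 result[idx] = value
+         else:
+             result[int(spec)] = value  # single index, possibly negative
+     return result
+
+ # {0: './img_a.png', 1: './img_b.png', 2: './img_b.png', -1: './img_c.png'}
+ print(parse_section_spec("0:./img_a.png;;;1-2:./img_b.png;;;-1:./img_c.png"))
+ ```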
482
+
483
+ ### **3. Section-Specific Prompts (`--prompt` Extended Format)**
484
+
485
+ * **Functionality:** Allows providing different text prompts for different sections of the video, enabling more granular control over the narrative or action flow.
486
+ * **Usage:** `--prompt "SECTION_SPEC:Prompt text for section(s);;;SECTION_SPEC:Another prompt;;;..."`
487
+ * `SECTION_SPEC`: Uses the same rules as `--image_path`.
488
+ * Use `;;;` as a separator.
489
+ * If a prompt for a specific section is not provided, the prompt associated with index `0` (or the closest specified applicable prompt) is typically used. Check behavior if defaults are critical.
490
+ * **Mechanism:** During the generation of each section, the corresponding section-specific prompt is used as the primary textual guidance for the model.
491
+ * **Prompt Content Recommendation** when using `--latent_paddings 0,0,0,0` without `--f1` (original FramePack model):
492
+ * Recall that FramePack uses Inverted Anti-drifting and references future context.
493
+ * It is recommended to describe "**the main content or state change that should occur in the current section, *and* the subsequent events or states leading towards the end of the video**" in the prompt for each section.
494
+ * Including the content of subsequent sections in the current section's prompt helps the model maintain context and overall coherence.
495
+ * Example: For section 1, the prompt might describe what happens in section 1 *and* briefly summarize section 2 (and beyond).
496
+ * However, based on observations (e.g., the `latent_paddings` comment), the model's ability to perfectly utilize very long-term context might be limited. Experimentation is key. Describing just the "goal for the current section" might also work. Start by trying the "section and onwards" approach.
497
+ * Write prompts in the usual way (a single prompt describing the whole video) when `latent_paddings` is >= 1 or `--latent_paddings` is not specified, or when using `--f1` (FramePack-F1 model).
498
+ * **Use Cases:** Describing evolving storylines, gradual changes in character actions or emotions, step-by-step processes over time.
499
+
500
+ ### **Combined Usage Example** (with `--f1` not specified)
501
+
502
+ Generating a 3-section video of "A dog runs towards a thrown ball, catches it, and runs back":
503
+
504
+ ```bash
505
+ python fpack_generate_video.py \
506
+ --prompt "0:A dog runs towards a thrown ball, catches it, and runs back;;;1:The dog catches the ball and then runs back towards the viewer;;;2:The dog runs back towards the viewer holding the ball" \
507
+ --image_path "0:./img_start_running.png;;;1:./img_catching.png;;;2:./img_running_back.png" \
508
+ --end_image_path ./img_returned.png \
509
+ --save_path ./output \
510
+ # ... other arguments
511
+ ```
512
+
513
+ * **Generation Order:** Section 2 -> Section 1 -> Section 0
514
+ * **Generating Section 2:**
515
+ * Prompt: "The dog runs back towards the viewer holding the ball"
516
+ * Start Image: `./img_running_back.png`
517
+ * End Image: `./img_returned.png` (Initial target)
518
+ * **Generating Section 1:**
519
+ * Prompt: "The dog catches the ball and then runs back towards the viewer"
520
+ * Start Image: `./img_catching.png`
521
+ * Future Context: Generated Section 2 latent
522
+ * **Generating Section 0:**
523
+ * Prompt: "A dog runs towards a thrown ball, catches it, and runs back"
524
+ * Start Image: `./img_start_running.png`
525
+ * Future Context: Generated Section 1 & 2 latents
526
+
527
+ ### **Important Considerations**
528
+
529
+ * **Inverted Generation:** Always remember that generation proceeds from the end of the video towards the beginning. Section `-1` (the last section, `2` in the example) is generated first.
530
+ * **Continuity vs. Guidance:** While start image guidance is powerful, drastically different images between sections might lead to unnatural transitions. Balance guidance strength with the need for smooth flow.
531
+ * **Prompt Optimization:** The prompt content recommendation is a starting point. Fine-tune prompts based on observed model behavior and desired output quality.
532
+
533
+ <details>
534
+ <summary>日本語</summary>
535
+
536
+ ### **高度な動画制御機能(実験的)**
537
+
538
+ このセクションでは、`fpack_generate_video.py` スクリプトに追加された実験的な機能について説明します。これらの機能は、生成される動画の内容をより詳細に制御するためのもので、特に長い動画や特定の遷移・状態が必要なシーケンスに役立ちます。これらの機能は、FramePack固有のInverted Anti-driftingサンプリング方式を活用しています。
539
+
540
+ #### **1. 終端画像ガイダンス (`--end_image_path`)**
541
+
542
+ * **機能:** 動画の最後のフレーム(群)を指定したターゲット画像に近づけるように生成を誘導します。
543
+ * **書式:** `--end_image_path <画像ファイルパス>`
544
+ * **動作:** 指定された画像はVAEでエンコードされ、その潜在表現が動画の最終セクション(Inverted Anti-driftingでは最初に生成される)の生成時の目標または開始点として使用されます。
545
+ * **用途:** キャラクターが特定のポーズで終わる、特定の商品がクローズアップで終わるなど、動画の結末を明確に定義する場合。
546
+
547
+ このオプションは、`--f1`を指定した場合は無視されます。FramePack-F1モデルでは終端画像は使用されません。
548
+
549
+ #### **2. セクション開始画像ガイダンス (`--image_path` 拡張書式)**
550
+
551
+ * **機能:** 動画内の特定のセクションが、指定された画像に近い視覚状態から始まるように誘導します。
552
+ * `--latent_paddings`を`0,0,0,0`(カンマ区切りでセクション数だけ指定)に設定することで、セクションの開始画像を強制できます。`latent_paddings`が1以上の場合、指定された画像は参照画像として使用されます。
553
+ * **書式:** `--image_path "セクション指定子:画像パス;;;セクション指定子:別の画像パス;;;..."`
554
+ * `セクション指定子`: 対象セクションを定義します。ルール:
555
+ * `0`: 動画の最初のセクション(Inverted Anti-driftingでは最後に生成)。
556
+ * `-1`: 動画の最後のセクション(最初に生成)。
557
+ * `N`(非負整数): N番目のセクション(0始まり)。
558
+ * `-N`(負整数): 最後からN番目のセクション。
559
+ * `S-E`(範囲, 例:`0-2`): セクションSからE(両端含む)に同じ画像を適用。
560
+ * 区切り文字は `;;;` です。
561
+ * セクションに画像が指定されていない場合、プロンプトと後続(未来時刻)セクションのコンテキストに基づいて生成されます。
562
+ * **動作:** 特定セクションの生成時、対応する開始画像が指定されていれば、そのVAE潜在表現がそのセクションの「初期状態」として強く参照されます。これにより、後続(生成済み)セクションとの時間的連続性を維持しようとしつつ、セクションの始まりを指定画像に近づけます。
563
+ * **用途:** シーン変更の起点を明確にする、特定のセクション開始時のキャラクターのポーズや服装を指定するなど。
564
+
565
+ #### **3. セクション別プロンプト (`--prompt` 拡張書式)**
566
+
567
+ * **機能:** 動画のセクションごとに異なるテキストプロンプトを与え、物語やアクションの流れをより細かく指示できます。
568
+ * **書式:** `--prompt "セクション指定子:プロンプトテキスト;;;セクション指定子:別のプロンプト;;;..."`
569
+ * `セクション指定子`: `--image_path` と同じルールです。
570
+ * 区切り文字は `;;;` です。
571
+ * 特定セクションのプロンプトがない場合、通常はインデックス`0`に関連付けられたプロンプト(または最も近い適用可能な指定プロンプト)が使用されます。デフォルトの挙動が重要な場合は確認してください。
572
+ * **動作:** 各セクションの生成時、対応するセクション別プロンプトがモデルへの主要なテキスト指示として使用されます。
573
+ * `latent_paddings`に`0`を指定した場合(非F1モデル)の **プロンプト内容の推奨:**
574
+ * FramePackはInverted Anti-driftingを採用し、未来のコンテキストを参照することを思い出してください。
575
+ * 各セクションのプロンプトには、「**現在のセクションで起こるべき主要な内容や状態変化、*および*それに続く動画の終端までの内容**」を記述することを推奨します。
576
+ * 現在のセクションのプロンプトに後続セクションの内容を含めることで、モデルが全体的な文脈を把握し、一貫性を保つのに役立ちます。
577
+ * 例:セクション1のプロンプトには、セクション1の内容 *と* セクション2の簡単な要約を記述します。
578
+ * ただし、モデルの長期コンテキスト完全利用能力には限界がある可能性も示唆されています(例:`latent_paddings`コメント)。実験が鍵となります。「現在のセクションの目標」のみを記述するだけでも機能する場合があります。まずは「セクションと以降」アプローチを試すことをお勧めします。
579
+ * 使用するプロンプトは、`latent_paddings`が`1`以上または指定されていない場合、または`--f1`(FramePack-F1モデル)を使用している場合は、通常のプロンプト内容を記述してください。
580
+ * **用途:** 時間経過に伴うストーリーの変化、キャラクターの行動や感情の段階的な変化、段階的なプロセスなどを記述する場合。
581
+
582
+ #### **組み合わせ使用例** (`--f1`未指定時)
583
+
584
+ 「投げられたボールに向かって犬が走り、それを捕まえ、走って戻ってくる」3セクション動画の生成:
585
+ (コマンド記述例は英語版を参考にしてください)
586
+
587
+ * **生成順序:** セクション2 → セクション1 → セクション0
588
+ * **セクション2生成時:**
589
+ * プロンプト: "犬がボールを咥えてこちらに向かって走ってくる"
590
+ * 開始画像: `./img_running_back.png`
591
+ * 終端画像: `./img_returned.png` (初期目標)
592
+ * **セクション1生成時:**
593
+ * プロンプト: "犬がボールを捕まえ、その後こちらに向かって走ってくる"
594
+ * 開始画像: `./img_catching.png`
595
+ * 未来コンテキスト: 生成済みセクション2の潜在表現
596
+ * **セクション0生成時:**
597
+ * プロンプト: "犬が投げられたボールに向かって走り、それを捕まえ、走って戻ってくる"
598
+ * 開始画像: `./img_start_running.png`
599
+ * 未来コンテキスト: 生成済みセクション1 & 2の潜在表現
600
+
601
+ #### **重要な考慮事項**
602
+
603
+ * **逆順生成:** 生成は動画の終わりから始まりに向かって進むことを常に意識してください。セクション`-1`(最後のセクション、上の例では `2`)が最初に生成されます。
604
+ * **連続性とガイダンスのバランス:** 開始画像ガイダンスは強力ですが、セクション間で画像が大きく異なると、遷移が不自然になる可能性があります。ガイダンスの強さとスムーズな流れの必要性のバランスを取ってください。
605
+ * **プロンプトの最適化:** 推奨されるプロンプト内容はあくまでも参考です。モデルの観察された挙動と望ましい出力品質に基づいてプロンプトを微調整してください。
606
+
607
+ </details>
docs/framepack_1f.md ADDED
@@ -0,0 +1,359 @@
 
 
1
+ # FramePack One Frame (Single Frame) Inference and Training / FramePack 1フレーム推論と学習
2
+
3
+ ## Overview / 概要
4
+
5
+ This document explains advanced inference and training methods using the FramePack model, particularly focusing on **"1-frame inference"** and its extensions. These features aim to leverage FramePack's flexibility to enable diverse image generation and editing tasks beyond simple video generation.
6
+
7
+ ### The Concept and Development of 1-Frame Inference
8
+
9
+ While FramePack is originally a model for generating sequential video frames (or frame sections), it was discovered that by focusing on its internal structure, particularly how it handles temporal information with RoPE (Rotary Position Embedding), interesting control over single-frame generation is possible.
10
+
11
+ 1. **Basic 1-Frame Inference**:
12
+ * It takes an initial image and a prompt as input, limiting the number of generated frames to just one.
13
+ * In this process, by intentionally setting a large RoPE timestamp (`target_index`) for the single frame to be generated, a single static image can be obtained that reflects temporal and semantic changes from the initial image according to the prompt.
14
+ * This utilizes FramePack's characteristic of being highly sensitive to RoPE timestamps, as it supports bidirectional contexts like "Inverted anti-drifting." This allows for operations similar to natural language-based image editing, albeit in a limited capacity, without requiring additional training.
15
+
16
+ 2. **Kisekaeichi Method (Feature Merging via Post-Reference)**:
17
+ * This method, an extension of basic 1-frame inference, was **proposed by furusu**. In addition to the initial image, it also uses a reference image corresponding to a "next section-start image" (treated as `clean_latent_post`) as input.
18
+ * The RoPE timestamp (`target_index`) for the image to be generated is set to an intermediate value between the timestamps of the initial image and the section-end image.
19
+ * More importantly, masking (e.g., zeroing out specific regions) is applied to the latent representation of each reference image (a sketch of this masking follows this list). For example, by setting masks to extract a character's face and body shape from the initial image and clothing textures from the reference image, an image can be generated that fuses the desired features of both, similar to a character "dress-up" or outfit swapping. This method can also be fundamentally achieved without additional training.
20
+
21
+ 3. **1f-mc (one frame multi-control) Method (Proximal Frame Blending)**:
22
+ * This method was **proposed by mattyamonaca**. It takes two reference images as input: an initial image (e.g., at `t=0`) and a subsequent image (e.g., at `t=1`, the first frame of a section), and generates a single image blending their features.
23
+ * Unlike Kisekaeichi, latent masking is typically not performed.
24
+ * To fully leverage this method, additional training using LoRA (Low-Rank Adaptation) is recommended. Through training, the model can better learn the relationship and blending method between the two input images to achieve specific editing effects.
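+ As a concrete illustration of the latent masking used by the Kisekaeichi method, the sketch below zeroes out masked regions of a reference latent (an assumed, simplified form; the cache and inference scripts may apply the mask differently, e.g. with soft values or per-frame handling):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def apply_latent_mask(latent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+     """latent: (C, H/8, W/8) VAE latent of a reference image.
+     mask: (H, W) float tensor in [0, 1], e.g. the image's alpha channel."""
+     lat_h, lat_w = latent.shape[-2:]
+     # downscale the pixel-space mask to the latent resolution
+     m = F.interpolate(mask[None, None], size=(lat_h, lat_w), mode="bilinear")[0]
+     # keep only the regions that should contribute features; zero out the rest
+     return latent * (m > 0.5).to(latent.dtype)
+ ```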
25
+
26
+ ### Integration into a Generalized Control Framework
27
+
28
+ The concepts utilized in the methods above—specifying reference images, manipulating timestamps, and applying latent masks—have been generalized to create a more flexible control framework.
29
+ Users can arbitrarily specify the following elements for both inference and LoRA training:
30
+
31
+ * **Control Images**: Any set of input images intended to influence the model.
32
+ * **Clean Latent Index (Indices)**: Timestamps corresponding to each control image. These are treated as `clean latent index` internally by FramePack and can be set to any position on the time axis. This is specified as `control_index`.
33
+ * **Latent Masks**: Masks applied to the latent representation of each control image, allowing selective control over which features from the control images are utilized. This is specified as `control_image_mask_path` or the alpha channel of the control image.
34
+ * **Target Index**: The timestamp for the single frame to be generated.
35
+
36
+ This generalized control framework, along with corresponding extensions to the inference and LoRA training tools, has enabled advanced applications such as:
37
+
38
+ * Development of LoRAs that stabilize 1-frame inference effects (e.g., a camera orbiting effect) that were previously unstable with prompts alone.
39
+ * Development of Kisekaeichi LoRAs that learn to perform desired feature merging under specific conditions (e.g., ignoring character information from a clothing reference image), thereby automating the masking process through learning.
40
+
41
+ These features maximize FramePack's potential and open up new creative possibilities in static image generation and editing. Subsequent sections will detail the specific options for utilizing these functionalities.
42
+
43
+ <details>
44
+ <summary>日本語</summary>
45
+
46
+ このドキュメントでは、FramePackモデルを用いた高度な推論および学習手法、特に「1フレーム推論」��その拡張機能について解説します。これらの機能は、FramePackの柔軟性を活かし、動画生成に留まらない多様な画像生成・編集タスクを実現することを目的としています。
47
+
48
+ ### 1フレーム推論の発想と発展
49
+
50
+ FramePackは本来、連続する動画フレーム(またはフレームセクション)を生成するモデルですが、その内部構造、特に時間情報を扱うRoPE (Rotary Position Embedding) の扱いに着目することで、単一フレームの生成においても興味深い制御が可能になることが発見されました。
51
+
52
+ 1. **基本的な1フレーム推論**:
53
+ * 開始画像とプロンプトを入力とし、生成するフレーム数を1フレームに限定します。
54
+ * この際、生成する1フレームに割り当てるRoPEのタイムスタンプ(`target_index`)を意図的に大きな値に設定することで、開始画像からプロンプトに従って時間的・意味的に変化した単一の静止画を得ることができます。
55
+ * これは、FramePackがInverted anti-driftingなどの双方向コンテキストに対応するため、RoPEのタイムスタンプに対して敏感に反応する特性を利用したものです。これにより、学習なしで限定的ながら自然言語による画像編集に近い操作が可能です。
56
+
57
+ 2. **kisekaeichi方式 (ポスト参照による特徴マージ)**:
58
+ * 基本的な1フレーム推論を発展させたこの方式は、**furusu氏により提案されました**。開始画像に加え、「次のセクションの開始画像」に相当する参照画像(`clean_latent_post`として扱われる)も入力として利用します。
59
+ * 生成する画像のRoPEタイムスタンプ(`target_index`)を、開始画像のタイムスタンプとセクション終端画像のタイムスタンプの中間的な値に設定します。
60
+ * さらに重要な点として、各参照画像のlatent表現に対してマスク処理(特定領域を0で埋めるなど)を施します。例えば、開始画像からはキャラクターの顔や体型を、参照画像からは服装のテクスチャを抽出するようにマスクを設定することで、キャラクターの「着せ替え」のような、両者の望ましい特徴を融合させた画像を生成できます。この手法も基本的には学習不要で実現可能です。
61
+
62
+ 3. **1f-mc (one frame multi-control) 方式 (近接フレームブレンド)**:
63
+ * この方式は、**mattyamonaca氏により提案されました**。開始画像(例: `t=0`)と、その直後の画像(例: `t=1`、セクションの最初のフレーム)の2つを参照画像として入力し、それらの特徴をブレンドした単一画像を生成します。
64
+ * kisekaeichiとは異なり、latentマスクは通常行いません。
65
+ * この方式の真価を発揮するには、LoRA (Low-Rank Adaptation) による追加学習が推奨されます。学習により、モデルは2つの入力画像間の関係性やブレンド方法をより適切に学習し、特定の編集効果を実現できます。
66
+
67
+ ### 汎用的な制御フレームワークへの統合
68
+
69
+ 上記の各手法で利用されていた「参照画像の指定」「タイムスタンプの操作」「latentマスクの適用」といった概念を一般化し、より柔軟な制御を可能にするための拡張が行われました。
70
+ ユーザーは以下の要素を任意に指定して、推論およびLoRA学習を行うことができます。
71
+
72
+ * **制御画像 (Control Images)**: モデルに影響を与えるための任意の入力画像群。
73
+ * **Clean Latent Index (Indices)**: 各制御画像に対応するタイムスタンプ。FramePack内部の`clean latent index`として扱われ、時間軸上の任意の位置を指定可能です。`control_index`として指定します。
74
+ * **Latentマスク (Latent Masks)**: 各制御画像のlatentに適用するマスク。これにより、制御画像から利用する特徴を選択的に制御します。`control_image_mask_path`または制御画像のアルファチャンネルとして指定します。
75
+ * **Target Index**: 生成したい単一フレームのタイムスタンプ。
76
+
77
+ この汎用的な制御フレームワークと、それに対応した推論ツールおよびLoRA学習ツールの拡張により、以下のような高度な応用が可能になりました。
78
+
79
+ * プロンプトだけでは不安定だった1フレーム推論の効果(例: カメラ旋回)を安定化させるLoRAの開発。
80
+ * マスク処理を手動で行う代わりに、特定の条件下(例: 服の参照画像からキャラクター情報を無視する)で望ましい特徴マージを行うように学習させたkisekaeichi LoRAの開発。
81
+
82
+ これらの機能は、FramePackのポテンシャルを最大限に引き出し、静止画生成・編集における新たな創造の可能性を拓くものです。以降のセクションでは、これらの機能を実際に利用するための具体的なオプションについて説明します。
83
+
84
+ </details>
85
+
86
+ ## One Frame (Single Frame) Training / 1フレーム学習
87
+
88
+ **This feature is experimental.** It trains in the same way as one frame inference.
89
+
90
+ The dataset must be an image dataset. If you use caption files, you need to specify `control_directory` and place the **start images** in that directory. The `image_directory` should contain the images after the change. The filenames of both directories must match. Caption files should be placed in the `image_directory`.
91
+
92
+ If you use JSONL files, specify them as `{"image_path": "/path/to/target_image1.jpg", "control_path": "/path/to/source_image1.jpg", "caption": "The object changes to red."}`. The `image_path` should point to the images after the change, and `control_path` should point to the starting images.
93
+
94
+ For the dataset configuration, see [here](../dataset/dataset_config.md#sample-for-image-dataset-with-control-images) and [here](../dataset/dataset_config.md#framepack-one-frame-training). There are also examples for kisekaeichi and 1f-mc settings.
95
+
96
+ For single frame training, specify `--one_frame` in `fpack_cache_latents.py` to create the cache. You can also use `--one_frame_no_2x` and `--one_frame_no_4x` options, which have the same meaning as `no_2x` and `no_4x` during inference. It is recommended to set these options to match the inference settings.
97
+
98
+ If you switch one frame training on or off, or change any of these options, overwrite the existing cache by re-running the caching script without specifying `--skip_existing`.
99
+
100
+ Specify `--one_frame` in `fpack_train_network.py` to change the inference method during sample generation.
101
+
102
+ The optimal training settings are currently unknown. Feedback is welcome.
103
+
104
+ ### Example of prompt file description for sample generation
105
+
106
+ The command line option `--one_frame_inference` corresponds to `--of`, and `--control_image_path` corresponds to `--ci`.
107
+
108
+ Note that while `--control_image_path` takes multiple paths after a single flag (e.g. `--control_image_path img1.png img2.png`), `--ci` must be repeated for each image (e.g. `--ci img1.png --ci img2.png`).
109
+
110
+ Normal single frame training:
111
+ ```
112
+ The girl wears a school uniform. --i path/to/start.png --ci path/to/start.png --of no_2x,no_4x,target_index=1,control_index=0 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
113
+ ```
114
+
115
+ Kisekaeichi training:
116
+ ```
117
+ The girl wears a school uniform. --i path/to/start_with_alpha.png --ci path/to/ref_with_alpha.png --ci path/to/start_with_alpha.png --of no_post,no_2x,no_4x,target_index=5,control_index=0;10 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
118
+ ```
119
+
120
+ <details>
121
+ <summary>日本語</summary>
122
+
123
+ **この機能は実験的なものです。** 1フレーム推論と同様の方法で学習を行います。
124
+
125
+ データセットは画像データセットである必要があります。キャプションファイルを用いる場合は、`control_directory`を追加で指定し、そのディレクトリに**開始画像**を格納してください。`image_directory`には変化後の画像を格納します。両者のファイル名は一致させる必要があります。キャプションファイルは`image_directory`に格納してください。
126
+
127
+ JSONLファイルを用いる場合は、`{"image_path": "/path/to/target_image1.jpg", "control_path": "/path/to/source_image1.jpg", "caption": "The object changes to red"}`のように指定してください。`image_path`は変化後の画像、`control_path`は開始画像を指定します。
128
+
129
+ データセットの設定については、[こちら](../dataset/dataset_config.md#sample-for-image-dataset-with-control-images)と[こちら](../dataset/dataset_config.md#framepack-one-frame-training)も参照してください。kisekaeichiと1f-mcの設定例もそちらにあります。
130
+
131
+ 1フレーム学習時は、`fpack_cache_latents.py`に`--one_frame`を指定してキャッシュを作成してください。また`--one_frame_no_2x`と`--one_frame_no_4x`オプションも利用可能です。推論時の`no_2x`、`no_4x`と同じ意味を持ちますので、推論時と同じ設定にすることをお勧めします。
132
+
133
+ 1フレーム学習か否かを変更する場合、またこれらのオプションを変更する場合は、`--skip_existing`を指定せずに既存のキャッシュを上書きしてください。
134
+
135
+ また、`fpack_train_network.py`に`--one_frame`を指定してサンプル画像生成時の推論方法を変更してください。
136
+
137
+ 最適な学習設定は今のところ不明です。フィードバックを歓迎します。
138
+
139
+ **サンプル生成のプロンプトファイル記述例**
140
+
141
+ コマンドラインオプション`--one_frame_inference`に相当する `--of`と、`--control_image_path`に相当する`--ci`が用意されています。
142
+
143
+ ※ `--ci`は複数指定可能ですが、`--control_image_path`は`--control_image_path img1.png img2.png`のようにスペースで区切るのに対して、`--ci`は`--ci img1.png --ci img2.png`のように指定するので注意してください。
144
+
145
+ 通常の1フレーム学習:
146
+ ```
147
+ The girl wears a school uniform. --i path/to/start.png --ci path/to/start.png --of no_2x,no_4x,target_index=1,control_index=0 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
148
+ ```
149
+
150
+ kisekaeichi方式:
151
+ ```
152
+ The girl wears a school uniform. --i path/to/start_with_alpha.png --ci path/to/ref_with_alpha.png --ci path/to/start_with_alpha.png --of no_post,no_2x,no_4x,target_index=5,control_index=0;10 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
153
+ ```
154
+
155
+ </details>
156
+
157
+ ## One (single) Frame Inference / 1フレーム推論
158
+
159
+ **This feature is highly experimental** and not officially supported. It is intended for users who want to explore the potential of FramePack for one frame inference, which is not a standard feature of the model.
160
+
161
+ This script also allows for one frame inference, which is not an official feature of FramePack but rather a custom implementation.
162
+
163
+ In principle, it generates the image that would appear a specified amount of time after the starting image, following the prompt. This means that, although limited, it allows for natural-language-based image editing.
164
+
165
+ To perform one frame inference, pass at least one flag to the `--one_frame_inference` option. Here is an example:
166
+
167
+ ```bash
168
+ --video_sections 1 --output_type latent_images --one_frame_inference default
169
+ ```
170
+
171
+ Setting `--one_frame_inference` to `default` or `no_2x,no_4x` is recommended. If you set `--output_type` to `latent_images`, both the latent and the image will be saved.
172
+
173
+ You can specify the following strings in the `--one_frame_inference` option, separated by commas:
174
+
175
+ - `no_2x`: Generates without passing the zero-vector clean latents 2x to the model. Slightly improves generation speed. The impact on generation results is unknown.
176
+ - `no_4x`: Generates without passing the zero-vector clean latents 4x to the model. Slightly improves generation speed. The impact on generation results is unknown.
177
+ - `no_post`: Generates without passing the zero-vector clean latents post to the model. Improves generation speed by about 20%, but may result in unstable generation.
178
+ - `target_index=<integer>`: Specifies the index of the image to be generated. The default is the last frame (i.e., `latent_window_size`).
179
+
180
+ For example, use `--one_frame_inference default` to pass clean latents 2x, clean latents 4x, and post to the model, or `--one_frame_inference no_2x,no_4x` to skip passing clean latents 2x and 4x. `--one_frame_inference target_index=9` can be used to specify the target index of the generated image.
181
+
182
+ The `--one_frame_inference` option also supports the advanced inference described in the next section, where additional parameters such as `target_index` and `control_index` allow for more detailed control.
183
+
184
+ Normally, specify `--video_sections 1` to indicate only one section (one image).
185
+
186
+ Increasing `target_index` from the default of 9 may result in larger changes. Generation has been confirmed to work without breaking down with values up to around 40.
187
+
188
+ The `--end_image_path` is ignored for one frame inference.
189
+
190
+ <details>
191
+ <summary>日本語</summary>
192
+
193
+ **この機能は非常に実験的であり**、公式にはサポートされていません。FramePackを使用して1フレーム推論の可能性を試したいユーザーに向けたものです。
194
+
195
+ このスクリプトでは、単一画像の推論を行うこともできます。FramePack公式の機能ではなく、独自の実装です。
196
+
197
+ 理論的には、開始画像から、プロンプトに従い、指定時間経過後の画像を生成します。つまり制限付きですが自然言語による画像編集を行うことができます。
198
+
199
+ 単一画像推論を行うには`--one_frame_inference`オプションに、何らかのオプションを指定してください。記述例は以下の通りです。
200
+
201
+ ```bash
202
+ --video_sections 1 --output_type latent_images --one_frame_inference default
203
+ ```
204
+
205
+ `--one_frame_inference`のオプションは、`default`または `no_2x,no_4x`を推奨します。`--output_type`に`latent_images`を指定するとlatentと画像の両方が保存されます。
206
+
207
+ `--one_frame_inference`のオプションには、カンマ区切りで以下のオプションを任意個数指定できます。
208
+
209
+ - `no_2x`: ゼロベクトルの clean latents 2xをモデルに渡さずに生成します。わずかに生成速度が向上します。生成結果への影響は不明です。
210
+ - `no_4x`: ゼロベクトルの clean latents 4xをモデルに渡さずに生成します。わずかに生成速度が向上します。生成結果への影響は不明です。
211
+ - `no_post`: ゼロベクトルの clean latents の post を渡さずに生成します。生成速度が20%程度向上しますが、生成結果が不安定になる場合があります。
212
+ - `target_index=<整数>`: 生成する画像のindexを指定します。デフォルトは最後のフレームです(=latent_window_size)。
213
+
214
+ たとえば、`--one_frame_inference default`を使用すると、clean latents 2x、clean latents 4x、postをモデルに渡します。`--one_frame_inference no_2x,no_4x`を使用すると、clean latents 2xと4xをモデルに渡すのをスキップします。`--one_frame_inference target_index=9`を使用して、生成する画像のターゲットインデックスを指定できます。
215
+
216
+ 後述の高度な推論では、このオプション内で `target_index`、`control_index` といった追加のパラメータを指定して、より詳細な制御が可能です。
217
+
218
+ clean latents 2x、clean latents 4x、postをモデルに渡す場合でも値はゼロベクトルですが、値を渡すか否かで結果は変わります。特に`no_post`を指定すると、`latent_window_size`を大きくしたときに生成結果が不安定になる場合があります。
219
+
220
+ 通常は`--video_sections 1` として1セクションのみ(画像1枚)を指定してください。
221
+
222
+ `target_index` をデフォルトの9から大きくすると、変化量が大きくなる可能性があります。40程度までは破綻なく生成されることを確認しています。
223
+
224
+ `--end_image_path`は無視されます。
225
+
226
+ </details>
227
+
228
+ ## kisekaeichi method (Post Reference Options) and 1f-mc (Multi-Control) / kisekaeichi方式(ポスト参照オプション)と1f-mc(マルチコントロール)
229
+
230
+ The `kisekaeichi` method was proposed by furusu. The `1f-mc` method was proposed by mattyamonaca in pull request [#304](https://github.com/kohya-ss/musubi-tuner/pull/304).
231
+
232
+ In this repository, these methods have been integrated and can be specified with the `--one_frame_inference` option. This allows for specifying any number of control images as clean latents, along with indices. This means you can specify multiple starting images and multiple clean latent posts. Additionally, masks can be applied to each image.
233
+
234
+ It is expected to work only with the standard (non-F1) FramePack model, not with the F1 model.
235
+
236
+ The following options have been added to `--one_frame_inference`. These can be used in conjunction with existing flags like `target_index`, `no_post`, `no_2x`, and `no_4x`.
237
+
238
+ - `control_index=<integer_or_semicolon_separated_integers>`: Specifies the index(es) of the clean latent for the control image(s). You must specify the same number of indices as the number of control images specified with `--control_image_path`.
239
+
240
+ Additionally, the following command-line options have been added. These arguments are only valid when `--one_frame_inference` is specified.
241
+
242
+ - `--control_image_path <path1> [<path2> ...]` : Specifies the path(s) to control (reference) image(s) for one frame inference. Provide one or more paths separated by spaces. Images with an alpha channel can be specified. If an alpha channel is present, it is used as a mask for the clean latent.
243
+ - `--control_image_mask_path <path1> [<path2> ...]` : Specifies the path(s) to grayscale mask(s) to be applied to the control image(s). Provide one or more paths separated by spaces. Each mask is applied to the corresponding control image. The 255 areas are referenced, while the 0 areas are ignored. A sketch of how such a mask can be applied to a control image's latent is shown after this list.
244
+
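+ For intuition, the following is a minimal sketch of how such a grayscale mask (or alpha channel) could be applied to a control image's clean latent: the mask is resized to the latent resolution and multiplied in, so 0 regions are zeroed out. The helper name and the exact resizing/interpolation used by the actual scripts are assumptions for illustration only.
+
+ ```python
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def apply_latent_mask(clean_latent: torch.Tensor, mask_path: str) -> torch.Tensor:
+     # clean_latent: latent of one control image, shape (C, 1, H/8, W/8)
+     # mask: grayscale image where 255 = referenced, 0 = ignored
+     lh, lw = clean_latent.shape[-2], clean_latent.shape[-1]
+     mask = Image.open(mask_path).convert("L").resize((lw, lh), Image.BILINEAR)
+     mask = torch.from_numpy(np.array(mask, dtype=np.float32) / 255.0)  # (H/8, W/8), values in [0, 1]
+     return clean_latent * mask  # masked-out regions become zero and are effectively ignored
+ ```
+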
245
+ **Example of specifying kisekaeichi:**
246
+
247
+ The kisekaeichi method works without training, but using a dedicated LoRA may yield better results.
248
+
249
+ ```bash
250
+ --video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png clean_latent_post_image.png \
251
+ --one_frame_inference target_index=1,control_index=0;10,no_post,no_2x,no_4x --control_image_mask_path ctrl_mask1.png ctrl_mask2.png
252
+ ```
253
+
254
+ In this example, `start_image.png` (for `clean_latent_pre`) and `clean_latent_post_image.png` (for `clean_latent_post`) are the reference images. The `target_index` specifies the index of the generated image. The `control_index` specifies the clean latent index for each control image, so it will be `0;10`. The masks for the control images are specified with `--control_image_mask_path`.
255
+
256
+ The optimal values for `target_index` and `control_index` are unknown. The `target_index` should be specified as 1 or higher. The `control_index` should be set to an appropriate value relative to `latent_window_size`. Specifying 1 for `target_index` results in less change from the starting image, but may introduce noise. Specifying 9 or 13 may reduce noise but result in larger changes from the original image.
257
+
258
+ The `control_index` should be larger than `target_index`. Typically, it is set to `10`, but larger values (e.g., around `13-16`) may also work.
259
+
260
+ Sample images and command lines for reproduction are as follows:
261
+
262
+ ```bash
263
+ python fpack_generate_video.py --video_size 832 480 --video_sections 1 --infer_steps 25 \
264
+ --prompt "The girl in a school blazer in a classroom." --save_path path/to/output --output_type latent_images \
265
+ --dit path/to/dit --vae path/to/vae --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2 \
266
+ --image_encoder path/to/image_encoder --attn_mode sdpa --vae_spatial_tile_sample_min_size 128 --vae_chunk_size 32 \
267
+ --image_path path/to/kisekaeichi_start.png --control_image_path path/to/kisekaeichi_start.png path/to/kisekaeichi_ref.png \
268
+ --one_frame_inference target_index=1,control_index=0;10,no_2x,no_4x,no_post \
269
+ --control_image_mask_path path/to/kisekaeichi_start_mask.png path/to/kisekaeichi_ref_mask.png --seed 1234
270
+ ```
271
+
272
+ Specify `--fp8_scaled` and `--blocks_to_swap` options according to your VRAM capacity.
273
+
274
+ - [kisekaeichi_start.png](./kisekaeichi_start.png)
275
+ - [kisekaeichi_ref.png](./kisekaeichi_ref.png)
276
+ - [kisekaeichi_start_mask.png](./kisekaeichi_start_mask.png)
277
+ - [kisekaeichi_ref_mask.png](./kisekaeichi_ref_mask.png)
278
+
279
+ Generation result: [kisekaeichi_result.png](./kisekaeichi_result.png)
280
+
281
+
282
+ **Example of 1f-mc (Multi-Control):**
283
+
284
+ ```bash
285
+ --video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png 2nd_image.png \
286
+ --one_frame_inference target_index=9,control_index=0;1,no_2x,no_4x
287
+ ```
288
+
289
+ In this example, `start_image.png` is the starting image, and `2nd_image.png` is the reference image. The `target_index=9` specifies the index of the generated image, while `control_index=0;1` specifies the clean latent indices for each control image.
290
+
291
+ 1f-mc is intended to be used in combination with a trained LoRA, so adjust `target_index` and `control_index` according to the LoRA's description.
292
+
293
+ <details>
294
+ <summary>日本語</summary>
295
+
296
+ `kisekaeichi`方式はfurusu氏により提案されました。また`1f-mc`方式はmattyamonaca氏によりPR [#304](https://github.com/kohya-ss/musubi-tuner/pull/304) で提案されました。
297
+
298
+ 当リポジトリではこれらの方式を統合し、`--one_frame_inference`オプションで指定できるようにしました。これにより、任意の枚数の制御用画像を clean latentとして指定し、さらにインデックスを指定できます。つまり開始画像の複数枚指定やclean latent postの複数枚指定などが可能です。また、それぞれの画像にマスクを適用することもできます。
299
+
300
+ なお、FramePack無印のみ動作し、F1モデルでは動作しないと思われます。
301
+
302
+ `--one_frame_inference`に以下のオプションが追加されています。`target_index`、`no_post`、`no_2x`や`no_4x`など既存のフラグと併用できます。
303
+
304
+ - `control_index=<整数またはセミコロン区切りの整数>`: 制御用画像のclean latentのインデックスを指定します。`--control_image_path`で指定した制御用画像の数と同じ数のインデックスを指定してください。
305
+
306
+ またコマンドラインオプションに以下が追加されています。これらの引数は`--one_frame_inference`を指定した場合のみ有効です。
307
+
308
+ - `--control_image_path <パス1> [<パス2> ...]` : 1フレーム推論用の制御用(参照)画像のパスを1つ以上、スペース区切りで指定します。アルファチャンネルを持つ画像が指定可能です。アルファチャンネルがある場合は、clean latentへのマスクとして利用されます。
309
+ - `--control_image_mask_path <パス1> [<パス2> ...]` : 制御用画像に適用するグレースケールマスクのパスを1つ以上、スペース区切りで指定します。各マスクは対応する制御用画像に適用されます。255の部分が参照される部分、0の部分が無視される部分です。
310
+
311
+ **kisekaeichiの指定例**:
312
+
313
+ kisekaeichi方式は学習なしでも動作しますが、専用のLoRAを使用することで、より良い結果が得られる可能性があります。
314
+
315
+ ```bash
316
+ --video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png clean_latent_post_image.png \
317
+ --one_frame_inference target_index=1,control_index=0;10,no_post,no_2x,no_4x --control_image_mask_path ctrl_mask1.png ctrl_mask2.png
318
+ ```
319
+
320
+ `start_image.png`(clean_latent_preに相当)と`clean_latent_post_image.png`は参照画像(clean_latent_postに相当)です。`target_index`は生成する画像のインデックスを指定します。`control_index`はそれぞれの制御用画像のclean latent indexを指定しますので、`0;10` になります。また`--control_image_mask_path`に制御用画像に適用するマスクを指定します。
321
+
322
+ `target_index`、`control_index`の最適値は不明です。`target_index`は1以上を指定してください。`control_index`は`latent_window_size`に対して適切な値を指定してください。`target_index`に1を指定すると開始画像からの変化が少なくなりますが、ノイズが乗ったりすることが多いようです。9や13などを指定するとノイズは改善されるかもしれませんが、元の画像からの変化が大きくなります。
323
+
324
+ `control_index`は`target_index`より大きい値を指定してください。通常は`10`ですが、これ以上大きな値、たとえば`13~16`程度でも動作するようです。
325
+
326
+ サンプル画像と再現のためのコマンドラインは以下のようになります。
327
+
328
+ ```bash
329
+ python fpack_generate_video.py --video_size 832 480 --video_sections 1 --infer_steps 25 \
330
+ --prompt "The girl in a school blazer in a classroom." --save_path path/to/output --output_type latent_images \
331
+ --dit path/to/dit --vae path/to/vae --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2 \
332
+ --image_encoder path/to/image_encoder --attn_mode sdpa --vae_spatial_tile_sample_min_size 128 --vae_chunk_size 32 \
333
+ --image_path path/to/kisekaeichi_start.png --control_image_path path/to/kisekaeichi_start.png path/to/kisekaeichi_ref.png \
334
+ --one_frame_inference target_index=1,control_index=0;10,no_2x,no_4x,no_post \
335
+ --control_image_mask_path path/to/kisekaeichi_start_mask.png path/to/kisekaeichi_ref_mask.png --seed 1234
336
+ ```
337
+
338
+ VRAM容量に応じて、`--fp8_scaled`や`--blocks_to_swap`等のオプションを調整してください。
339
+
340
+ - [kisekaeichi_start.png](./kisekaeichi_start.png)
341
+ - [kisekaeichi_ref.png](./kisekaeichi_ref.png)
342
+ - [kisekaeichi_start_mask.png](./kisekaeichi_start_mask.png)
343
+ - [kisekaeichi_ref_mask.png](./kisekaeichi_ref_mask.png)
344
+
345
+ 生成結果:
346
+ - [kisekaeichi_result.png](./kisekaeichi_result.png)
347
+
348
+ **1f-mcの指定例**:
349
+
350
+ ```bash
351
+ --video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png 2nd_image.png \
352
+ --one_frame_inference target_index=9,control_index=0;1,no_2x,no_4x
353
+ ```
354
+
355
+ この例では、`start_image.png`が開始画像で、`2nd_image.png`が参照画像です。`target_index=9`は生成する画像のインデックスを指定し、`control_index=0;1`はそれぞれの制御用画像のclean latent indexを指定しています。
356
+
357
+ 1f-mcは学習したLoRAと組み合わせることを想定していますので、そのLoRAの説明に従って、`target_index`や`control_index`を調整してください。
358
+
359
+ </details>
docs/kisekaeichi_ref.png ADDED

Git LFS Details

  • SHA256: e5037f0a0cfb1a6b0a8d1f19fb462df75fb53384d0d9e654c359ca984fafa605
  • Pointer size: 131 Bytes
  • Size of remote file: 584 kB
docs/kisekaeichi_ref_mask.png ADDED
docs/kisekaeichi_result.png ADDED

Git LFS Details

  • SHA256: 223dacb98ac834a442ee124641a6b852b1cde3bc1f11939e78192fc8be2f7b49
  • Pointer size: 131 Bytes
  • Size of remote file: 408 kB
docs/kisekaeichi_start.png ADDED

Git LFS Details

  • SHA256: beee4a910402ef2798b00aa4d193b0b7186380ed24928a4d39acc8635d2cfdaf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
docs/kisekaeichi_start_mask.png ADDED
docs/sampling_during_training.md CHANGED
@@ -72,16 +72,20 @@ A line starting with `#` is a comment.
72
  * `--f` specifies the number of frames. The default is 1, which generates a still image.
73
  * `--d` specifies the seed. The default is random.
74
  * `--s` specifies the number of steps in generation. The default is 20.
75
- * `--g` specifies the guidance scale. The default is 6.0, which is the default value during inference of HunyuanVideo. Specify 1.0 for SkyReels V1 models. Ignore this option for Wan2.1 models.
76
- * `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to the number of steps 20. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for less than 20 steps (e.g. 10).
77
 
78
- If you train I2V models, you can use the additional options below.
79
 
80
  * `--i path/to/image.png`: the image path for image2video inference.
81
 
82
- If you train the model with classifier free guidance, you can use the additional options below.
83
 
84
- *`--n negative prompt...`: the negative prompt for the classifier free guidance.
 
 
 
 
85
  *`--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
86
 
87
  <details>
@@ -94,15 +98,19 @@ If you train the model with classifier free guidance, you can use the additional
94
  * `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
95
  * `--d` シードを指定します。省略時はランダムです。
96
  * `--s` 生成におけるステップ数を指定します。省略時は20です。
97
- * `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。
98
- * `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
99
 
100
- I2Vモデルを学習する場合、以下の追加オプションを使用できます。
101
 
102
  * `--i path/to/image.png`: image2video推論用の画像パス。
103
 
104
- classifier free guidance(ネガティブプロンプト)を必要とするモデルを学習する場合、以下の追加オプションを使用できます。
 
 
 
 
105
 
106
- *`--n negative prompt...`: classifier free guidance用のネガティブプロンプト。
107
  *`--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
108
  </details>
 
72
  * `--f` specifies the number of frames. The default is 1, which generates a still image.
73
  * `--d` specifies the seed. The default is random.
74
  * `--s` specifies the number of steps in generation. The default is 20.
75
+ * `--g` specifies the embedded guidance scale (not the CFG scale). The default is 6.0 for HunyuanVideo and 10.0 for FramePack, matching each architecture's default inference value. Specify 1.0 for SkyReels V1 models. This option is ignored for Wan2.1 models.
76
+ * `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to 20 steps. The HunyuanVideo paper recommends 7.0 for 50 steps and 17.0 for fewer than 20 steps (e.g. 10). This option is ignored for FramePack models (10.0 is always used).
77
 
78
+ If you train I2V models, you must add the following option.
79
 
80
  * `--i path/to/image.png`: the image path for image2video inference.
81
 
82
+ If you train Wan2.1-Fun-Control models, you must add the following option.
83
 
84
+ * `--cn path/to/control_video_or_dir_of_images`: the path to the video or directory containing multiple images for control.
85
+
86
+ If you train the model with classifier free guidance (such as Wan2.1), you can use the additional options below.
87
+
88
+ *`--n negative prompt...`: the negative prompt for the classifier free guidance. The default prompt for each model is used if omitted.
89
  *`--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
90
 
91
  <details>
 
98
  * `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
99
  * `--d` シードを指定します。省略時はランダムです。
100
  * `--s` 生成におけるステップ数を指定します。省略時は20です。
101
+ * `--g` embedded guidance scaleを指定します(CFG scaleではありません)。省略時はHunyuanVideoは6.0、FramePackは10.0で、各アーキテクチャの推論時のデフォルト値です。SkyReels V1モデルの場合は1.0を指定してください。Wan2.1モデルの場合はこのオプションは無視されます。
102
+ * `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。FramePackモデルはこのオプションは無視され、10.0が使用されます。
103
 
104
+ I2Vモデルを学習する場合、以下のオプションを追加してください。
105
 
106
  * `--i path/to/image.png`: image2video推論用の画像パス。
107
 
108
+ Wan2.1-Fun-Controlモデルを学習する場合、以下のオプションを追加してください。
109
+
110
+ * `--cn path/to/control_video_or_dir_of_images`: control用の動画または複数枚の画像を含むディレクトリのパス。
111
+
112
+ classifier free guidance(ネガティブプロンプト)を必要とするモデル(Wan2.1など)を学習する場合、以下の追加オプションを使用できます。
113
 
114
+ *`--n negative prompt...`: classifier free guidance用のネガティブプロンプト。省略時はモデルごとのデフォルトプロンプトが使用されます。
115
  *`--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
116
  </details>
docs/wan.md CHANGED
@@ -27,24 +27,45 @@ This feature is experimental.
27
 
28
  ## Download the model / モデルのダウンロード
29
 
30
- Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
31
 
32
  Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
33
 
34
  Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
35
 
36
- Please select the appropriate weights according to T2V, I2V, resolution, model size, etc. fp8 models can be used if `--fp8` is specified.
 
 
 
 
37
 
38
  (Thanks to Comfy-Org for providing the repackaged weights.)
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  <details>
40
  <summary>日本語</summary>
41
- T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
42
 
43
  VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
44
 
45
  DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
46
 
47
- T2VI2V、解像度、モデルサイズなどにより適切な重みを選択してください。`--fp8`指定時はfp8モデルも使用できます。
 
 
 
 
48
 
49
  (repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
50
  </details>
@@ -63,6 +84,8 @@ If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-l
63
 
64
  If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
65
 
 
 
66
  <details>
67
  <summary>日本語</summary>
68
  latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
@@ -70,6 +93,8 @@ latentの事前キャッシングはHunyuanVideoとほぼ同じです。上の
70
  I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
71
 
72
  VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
 
 
73
  </details>
74
 
75
  ### Text Encoder Output Pre-caching
@@ -115,7 +140,7 @@ The above is an example. The appropriate values for `timestep_sampling` and `dis
115
 
116
  For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
117
 
118
- `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`. Specify the DiT weights for the task with `--dit`.
119
 
120
  Don't forget to specify `--network_module networks.lora_wan`.
121
 
@@ -129,7 +154,7 @@ Use `convert_lora.py` for converting the LoRA weights after training, as in Huny
129
 
130
  その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
131
 
132
- `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
133
 
134
  `--network_module` に `networks.lora_wan` を指定することを忘れないでください。
135
 
@@ -152,7 +177,7 @@ Each option is the same as when generating images or as HunyuanVideo. Please ref
152
 
153
  If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
154
 
155
- You can specify the initial image and negative prompts in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
156
 
157
  <details>
158
  <summary>日本語</summary>
@@ -160,12 +185,23 @@ You can specify the initial image and negative prompts in the prompt file. Pleas
160
 
161
  I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
162
 
163
- プロンプトファイルで、初期画像やネガティブプロンプト等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
164
  </details>
165
 
166
 
167
  ## Inference / 推論
168
 
 
 
 
 
 
 
 
 
 
 
 
169
  ### T2V Inference / T2V推論
170
 
171
  The following is an example of T2V inference (input as a single line):
@@ -178,30 +214,64 @@ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video
178
  --attn_mode torch
179
  ```
180
 
181
- `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`.
182
 
183
  `--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`,`flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
184
 
 
 
 
 
 
 
185
  `--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
186
 
187
  `--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
188
 
189
- ` --flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
190
 
191
- `--guidance_scale` can be used to specify the guidance scale for classifier free guiance (default 5.0).
192
 
193
  `--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for 14B model and 29 for 1.3B model.
194
 
195
  `--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  Other options are same as `hv_generate_video.py` (some options are not supported, please check the help).
198
 
199
  <details>
200
  <summary>日本語</summary>
201
- `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。
202
 
203
  `--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
204
 
 
 
 
 
 
 
205
  `--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
206
 
207
  `--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
@@ -214,9 +284,116 @@ Other options are same as `hv_generate_video.py` (some options are not supported
214
 
215
  `--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
218
  </details>
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  ### I2V Inference / I2V推論
221
 
222
  The following is an example of I2V inference (input as a single line):
@@ -231,11 +408,124 @@ python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_l
231
 
232
  Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
233
 
 
 
 
 
234
  Other options are same as T2V inference.
235
 
236
  <details>
237
  <summary>日本語</summary>
238
  `--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
239
 
 
 
 
 
240
  その他のオプションはT2V推論と同じです。
241
  </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  ## Download the model / モデルのダウンロード
29
 
30
+ Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
31
 
32
  Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
33
 
34
  Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
35
 
36
+ Wan2.1 Fun Control model weights can be downloaded from [here](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control). Navigate to each weight page and download. The Fun Control model seems to support not only T2V but also I2V tasks.
37
+
38
+ Please select the appropriate weights according to T2V, I2V, resolution, model size, etc.
39
+
40
+ `fp16` and `bf16` models can be used, and `fp8_e4m3fn` models can be used if `--fp8` (or `--fp8_base`) is specified without specifying `--fp8_scaled`. **Please note that `fp8_scaled` models are not supported even with `--fp8_scaled`.**
41
 
42
  (Thanks to Comfy-Org for providing the repackaged weights.)
43
+
44
+ ### Model support matrix / モデルサポートマトリックス
45
+
46
+ * columns: training dtype (列:学習時のデータ型)
47
+ * rows: model dtype (行:モデルのデータ型)
48
+
49
+ | model \ training |bf16|fp16|--fp8_base|--fp8_base & --fp8_scaled|
50
+ |--|--|--|--|--|
51
+ |bf16|✓|--|✓|✓|
52
+ |fp16|--|✓|✓|✓|
53
+ |fp8_e4m3fn|--|--|✓|--|
54
+ |fp8_scaled|--|--|--|--|
55
+
56
  <details>
57
  <summary>日本語</summary>
58
+ T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
59
 
60
  VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
61
 
62
  DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
63
 
64
+ Wan2.1 Fun Controlモデルの重みは、[こちら](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control)から、それぞれの重みのページに遷移し、ダウンロードしてください。Fun ControlモデルはT2VだけでなくI2Vタスクにも対応しているようです。
65
+
66
+ T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。
67
+
68
+ `fp16` および `bf16` モデルを使用できます。また、`--fp8` (または`--fp8_base`)を指定し`--fp8_scaled`を指定をしないときには `fp8_e4m3fn` モデルを使用できます。**`fp8_scaled` モデルはいずれの場合もサポートされていませんのでご注意ください。**
69
 
70
  (repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
71
  </details>
 
84
 
85
  If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
86
 
87
+ The control video settings are required for training the Fun-Control model. Please refer to [Dataset Settings](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images) for details.
88
+
89
  <details>
90
  <summary>日本語</summary>
91
  latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
 
93
  I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
94
 
95
  VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
96
+
97
+ Fun-Controlモデルを学習する場合は、制御用動画の設定が必要です。[データセット設定](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images)を参照してください。
98
  </details>
99
 
100
  ### Text Encoder Output Pre-caching
 
140
 
141
  For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
142
 
143
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (for Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC`, and `i2v-14B-FC` (for Wan2.1 Fun Control model). Specify the DiT weights for the task with `--dit`.
144
 
145
  Don't forget to specify `--network_module networks.lora_wan`.
146
 
 
154
 
155
  その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
156
 
157
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
158
 
159
  `--network_module` に `networks.lora_wan` を指定することを忘れないでください。
160
 
 
177
 
178
  If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
179
 
180
+ You can specify the initial image, the negative prompt and the control video (for Wan2.1-Fun-Control) in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
181
 
182
  <details>
183
  <summary>日本語</summary>
 
185
 
186
  I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
187
 
188
+ プロンプトファイルで、初期画像やネガティブプロンプト、制御動画(Wan2.1-Fun-Control用)等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
189
  </details>
190
 
191
 
192
  ## Inference / 推論
193
 
194
+ ### Inference Options Comparison / 推論オプション比較
195
+
196
+ #### Speed Comparison (Faster → Slower) / 速度比較(速い→遅い)
197
+ *Note: Results may vary depending on GPU type*
198
+
199
+ fp8_fast > bf16/fp16 (no block swap) > fp8 > fp8_scaled > bf16/fp16 (block swap)
200
+
201
+ #### Quality Comparison (Higher → Lower) / 品質比較(高→低)
202
+
203
+ bf16/fp16 > fp8_scaled > fp8 >> fp8_fast
204
+
205
  ### T2V Inference / T2V推論
206
 
207
  The following is an example of T2V inference (input as a single line):
 
214
  --attn_mode torch
215
  ```
216
 
217
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (these are Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC` and `i2v-14B-FC` (for Wan2.1-Fun Control model).
218
 
219
  `--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`,`flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
220
 
221
+ Specifying `--fp8` runs DiT in fp8 mode. fp8 can significantly reduce memory consumption but may impact output quality.
222
+
223
+ `--fp8_scaled` can be specified in addition to `--fp8` to apply fp8 weight optimization to the model. This slightly increases memory consumption and inference time, but improves output quality. See [here](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化) for details.
224
+
225
+ The `--fp8_fast` option is also available for faster inference on RTX 40x0 GPUs. It requires the `--fp8_scaled` option. **This option seems to degrade the output quality.**
226
+
227
  `--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
228
 
229
  `--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
230
 
231
+ `--flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
232
 
233
+ `--guidance_scale` can be used to specify the guidance scale for classifier free guidance (default 5.0).
234
 
235
  `--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for 14B model and 29 for 1.3B model.
236
 
237
  `--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
238
 
239
+ `--compile` enables torch.compile. See [here](/README.md#inference) for details.
240
+
241
+ `--trim_tail_frames` can be used to trim the tail frames when saving. The default is 0.
242
+
243
+ `--cfg_skip_mode` specifies the mode for skipping CFG at certain steps. The default is `none` (CFG is applied at all steps). `--cfg_apply_ratio` specifies the ratio of steps where CFG is applied. See below for details.
244
+
245
+ `--include_patterns` and `--exclude_patterns` can be used to specify which LoRA modules to apply or exclude during inference. If not specified, all modules are applied by default. These options accept regular expressions.
246
+
247
+ `--include_patterns` specifies the modules to be applied, and `--exclude_patterns` specifies the modules to be excluded. The regular expression is matched against the LoRA key name, and include takes precedence.
248
+
249
+ The key name to be searched is in sd-scripts format (`lora_unet_<module_name with dot replaced by _>`). For example, `lora_unet_blocks_9_cross_attn_k`.
250
+
251
+ For example, if you specify `--exclude_patterns "blocks_[23]\d_"`, it will exclude modules containing `blocks_20` to `blocks_39`. If you specify `--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`, it will apply LoRA to modules containing `cross_attn` and not containing `blocks_0` to `blocks_4`.
252
+
253
+ If you specify multiple LoRA weights, please specify them with multiple arguments. For example: `--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"`. `".*"` is a regex that matches everything. `dummy_do_not_exclude` is a dummy regex that does not match anything.
254
+
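+ One way to picture the selection rule, consistent with the combined example above, is the sketch below. The helper name is made up for illustration, and the actual handling of precedence and of per-LoRA pattern lists in the script may differ.
+
+ ```python
+ import re
+
+
+ def lora_module_enabled(key: str, include_patterns: list[str], exclude_patterns: list[str]) -> bool:
+     # key is an sd-scripts style LoRA key, e.g. "lora_unet_blocks_9_cross_attn_k"
+     if include_patterns and not any(re.search(p, key) for p in include_patterns):
+         return False  # include patterns define the base set of modules
+     if exclude_patterns and any(re.search(p, key) for p in exclude_patterns):
+         return False  # exclude patterns then remove modules from that set
+     return True       # nothing specified (or nothing excluded): applied
+
+
+ # apply LoRA to cross attention modules, except in blocks 0-4
+ print(lora_module_enabled("lora_unet_blocks_9_cross_attn_k", ["cross_attn"], [r"blocks_(0|1|2|3|4)_"]))  # True
+ print(lora_module_enabled("lora_unet_blocks_2_cross_attn_k", ["cross_attn"], [r"blocks_(0|1|2|3|4)_"]))  # False
+ ```
+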
255
+ `--cpu_noise` generates initial noise on the CPU. This may result in the same results as ComfyUI with the same seed (depending on other settings).
256
+
257
+ If you are using the Fun Control model, specify the control video with `--control_path`. You can specify a video file or a folder containing multiple image files. The number of frames in the video file (or the number of images) should be at least the number specified in `--video_length` (plus 1 frame if you specify `--end_image_path`).
258
+
259
+ Please try to match the aspect ratio of the control video with the aspect ratio specified in `--video_size` (there may be some deviation from the initial image of I2V due to the use of bucketing processing).
260
+
261
  Other options are same as `hv_generate_video.py` (some options are not supported, please check the help).
262
 
263
  <details>
264
  <summary>日本語</summary>
265
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。
266
 
267
  `--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
268
 
269
+ `--fp8` を指定するとDiTモデルをfp8形式で実行します。fp8はメモリ消費を大幅に削減できますが、出力品質に影響を与える可能性があります。
270
+
271
+ `--fp8_scaled` を `--fp8` と併用すると、fp8への重み量子化を行います。メモリ消費と速度はわずかに悪化しますが、出力品質が向上します。詳しくは[こちら](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化)を参照してください。
272
+
273
+ `--fp8_fast` オプションはRTX 40x0 GPUでの高速推論に使用されるオプションです。このオプションは `--fp8_scaled` オプションが必要です。**出力品質が劣化するようです。**
274
+
275
  `--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
276
 
277
  `--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
 
284
 
285
  `--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
286
 
287
+ `--compile`でtorch.compileを有効にします。詳細については[こちら](/README.md#inference)を参照してください。
288
+
289
+ `--trim_tail_frames` で保存時に末尾のフレームをトリミングできます。デフォルトは0です。
290
+
291
+ `--cfg_skip_mode` は異なるステップでCFGをスキップするモードを指定します。デフォルトは `none`(全ステップ)。`--cfg_apply_ratio` はCFGが適用されるステップの割合を指定します。詳細は後述します。
292
+
293
+ LoRAのどのモジュールを適用するかを、`--include_patterns`と`--exclude_patterns`で指定できます(未指定時・デフォルトは全モジュール適用されます
294
+ )。これらのオプションには、正規表現を指定します。`--include_patterns`は適用するモジュール、`--exclude_patterns`は適用しないモジュールを指定します。正規表現がLoRAのキー名に含まれるかどうかで判断され、includeが優先されます。
295
+
296
+ 検索対象となるキー名は sd-scripts 形式(`lora_unet_<モジュール名のドットを_に置換したもの>`)です。例:`lora_unet_blocks_9_cross_attn_k`
297
+
298
+ たとえば `--exclude_patterns "blocks_[23]\d_"`のみを指定すると、`blocks_20`から`blocks_39`を含むモジュールが除外されます。`--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`のようにincludeとexcludeを指定すると、`cross_attn`を含むモジュールで、かつ`blocks_0`から`blocks_4`を含まないモジュールにLoRAが適用されます。
299
+
300
+ 複数のLoRAの重みを指定する場合は、複数個の引数で指定してください。例:`--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"` `".*"`は全てにマッチする正規表現です。`dummy_do_not_exclude`は何にもマッチしないダミーの正規表現です。
301
+
302
+ `--cpu_noise`を指定すると初期ノイズをCPUで生成します。これにより同一seed時の結果がComfyUIと同じになる可能性があります(他の設定にもよります)。
303
+
304
+ Fun Controlモデルを使用する場合は、`--control_path`で制御用の映像を指定します。動画ファイル、または複数枚の画像ファイルを含んだフォルダを指定できます。動画ファイルのフレーム数(または画像の枚数)は、`--video_length`で指定したフレーム数以上にしてください(後述の`--end_image_path`を指定した場合は、さらに+1フレーム)。
305
+
306
+ 制御用の映像のアスペクト比は、`--video_size`で指定したアスペクト比とできるかぎり合わせてください(bucketingの処理を流用しているためI2Vの初期画像とズレる場合があります)。
307
+
308
  その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
309
  </details>
310
 
311
+ #### CFG Skip Mode / CFGスキップモード
312
+
313
+ These options allow you to balance generation speed against prompt accuracy. The more steps are skipped, the faster the generation, at the cost of potential quality degradation.
314
+
315
+ Setting `--cfg_apply_ratio` to 0.5 speeds up the denoising loop by up to 25%, because skipped steps run only the conditional forward pass instead of both the conditional and unconditional passes.
316
+
317
+ `--cfg_skip_mode` accepts one of the following modes:
318
+
319
+ - `early`: Skips CFG in early steps for faster generation, applying guidance mainly in later refinement steps
320
+ - `late`: Skips CFG in later steps, applying guidance during initial structure formation
321
+ - `middle`: Skips CFG in middle steps, applying guidance in both early and later steps
322
+ - `early_late`: Skips CFG in both early and late steps, applying only in middle steps
323
+ - `alternate`: Applies CFG in alternate steps based on the specified ratio
324
+ - `none`: Applies CFG at all steps (default)
325
+
326
+ `--cfg_apply_ratio` specifies a value from 0.0 to 1.0 controlling the proportion of steps where CFG is applied. For example, setting 0.5 means CFG will be applied in only 50% of the steps.
327
+
328
+ If num_steps is 10, the following table shows the steps where CFG is applied based on the `--cfg_skip_mode` option (A means CFG is applied, S means it is skipped, `--cfg_apply_ratio` is 0.6):
329
+
330
+ | skip mode | CFG apply pattern |
331
+ |---|---|
332
+ | early | SSSSAAAAAA |
333
+ | late | AAAAAASSSS |
334
+ | middle | AAASSSSAAA |
335
+ | early_late | SSAAAAAASS |
336
+ | alternate | SASASAASAS |
337
+
338
+ The appropriate settings are unknown, but you may want to try `late` or `early_late` mode with a ratio of around 0.3 to 0.5.
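+
+ As a rough illustration, the block-style modes in the table can be reproduced with a sketch like the one below. The helper name is made up, and the actual step selection in the script may differ, in particular in rounding and in how `alternate` interleaves the applied steps.
+
+ ```python
+ def cfg_apply_flags(num_steps: int, mode: str, ratio: float) -> str:
+     # Returns one character per step: A = CFG applied, S = CFG skipped.
+     apply = round(num_steps * ratio)   # steps that keep CFG
+     skip = num_steps - apply           # steps that drop the unconditional pass
+     if mode == "early":                # skip early steps, guide the later refinement steps
+         flags = [False] * skip + [True] * apply
+     elif mode == "late":               # guide the early structure-forming steps, skip later ones
+         flags = [True] * apply + [False] * skip
+     elif mode == "middle":             # guide both ends, skip the middle
+         head = apply // 2
+         flags = [True] * head + [False] * skip + [True] * (apply - head)
+     elif mode == "early_late":         # skip both ends, guide only the middle
+         head = skip // 2
+         flags = [False] * head + [True] * apply + [False] * (skip - head)
+     else:                              # "none"; "alternate" instead interleaves A and S
+         flags = [True] * num_steps
+     return "".join("A" if f else "S" for f in flags)
+
+
+ print(cfg_apply_flags(10, "early_late", 0.6))  # SSAAAAAASS, as in the table above
+ ```
+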
339
+ <details>
340
+ <summary>日本語</summary>
341
+ これらのオプションは、生成速度とプロンプトの精度のバランスを取ることができます。スキップされるステップが多いほど、生成速度が速くなりますが、品質が低下する可能性があります。
342
+
343
+ ratioに0.5を指定���ることで、デノイジングのループが最大25%程度、高速化されます。
344
+
345
+ `--cfg_skip_mode` は次のモードのいずれかを指定します:
346
+
347
+ - `early`:初期のステップでCFGをスキップして、主に終盤の精細化のステップで適用します
348
+ - `late`:終盤のステップでCFGをスキップし、初期の構造が決まる段階で適用します
349
+ - `middle`:中間のステップでCFGをスキップし、初期と終盤のステップの両方で適用します
350
+ - `early_late`:初期と終盤のステップの両方でCFGをスキップし、中間のステップのみ適用します
351
+ - `alternate`:指定された割合に基づいてCFGを適用します
352
+
353
+ `--cfg_apply_ratio` は、CFGが適用されるステップの割合を0.0から1.0の値で指定します。たとえば、0.5に設定すると、CFGはステップの50%のみで適用されます。
354
+
355
+ 具体的なパターンは上のテーブルを参照してください。
356
+
357
+ 適切な設定は不明ですが、モードは`late`または`early_late`、ratioは0.3~0.5程度から試してみると良いかもしれません。
358
+ </details>
359
+
360
+ #### Skip Layer Guidance
361
+
362
+ Skip Layer Guidance is a feature that uses the output of a model with some blocks skipped as the unconditional output of classifier free guidance. It was originally proposed in [SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404) and first applied in Wan2GP in [this PR](https://github.com/deepbeepmeep/Wan2GP/pull/61). It may improve the quality of generated videos.
363
+
364
+ The implementation of SD 3.5 is [here](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py), and the implementation of Wan2GP (the PR mentioned above) has some different specifications. This inference script allows you to choose between the two methods.
365
+
366
+ *The SD3.5 method applies slg output in addition to cond and uncond (slows down the speed). The Wan2GP method uses only cond and slg output.*
367
+
368
+ The following arguments are available:
369
+
370
+ - `--slg_mode`: Specifies the SLG mode. `original` for SD 3.5 method, `uncond` for Wan2GP method. Default is None (no SLG).
371
+ - `--slg_layers`: Specifies the indices of the blocks (layers) to skip in SLG, separated by commas. Example: `--slg_layers 4,5,6`. Default is empty (no skip). If this option is not specified, `--slg_mode` is ignored.
372
+ - `--slg_scale`: Specifies the scale of SLG when `original`. Default is 3.0.
373
+ - `--slg_start`: Specifies the start step of SLG application in inference steps from 0.0 to 1.0. Default is 0.0 (applied from the beginning).
374
+ - `--slg_end`: Specifies the end step of SLG application in inference steps from 0.0 to 1.0. Default is 0.3 (applied up to 30% from the beginning).
375
+
376
+ Appropriate settings are unknown, but you may want to try `original` mode with a scale of around 3.0 and a start ratio of 0.0 and an end ratio of 0.5, with layers 4, 5, and 6 skipped.
377
+
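+ For intuition, the two modes combine the three model outputs roughly as sketched below. This is based on the descriptions above and the SD 3.5 reference implementation; the function name is made up and the exact formulas used by this script may differ.
+
+ ```python
+ def combine_with_slg(cond, uncond, slg_out, guidance_scale, slg_scale, slg_mode):
+     # cond / uncond: model outputs with and without the prompt (tensors of the same shape)
+     # slg_out: output of the model with the --slg_layers blocks skipped
+     if slg_mode == "original":   # SD 3.5 style: normal CFG plus an extra skip-layer term
+         pred = uncond + guidance_scale * (cond - uncond)
+         pred = pred + slg_scale * (cond - slg_out)
+     elif slg_mode == "uncond":   # Wan2GP style: the skip-layer output replaces uncond
+         pred = slg_out + guidance_scale * (cond - slg_out)
+     else:                        # SLG disabled: plain classifier-free guidance
+         pred = uncond + guidance_scale * (cond - uncond)
+     return pred
+ ```
+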
378
+ <details>
379
+ <summary>日本語</summary>
380
+ Skip Layer Guidanceは、一部のblockをスキップしたモデル出力をclassifier free guidanceのunconditional出力に使用する機能です。元々は[SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404)で提案されたもので、Wan2.1には[Wan2GPのこちらのPR](https://github.com/deepbeepmeep/Wan2GP/pull/61)で初めて適用されました。生成動画の品質が向上する可能性があります。
381
+
382
+ SD 3.5の実装は[こちら](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py)で、Wan2GPの実装(前述のPR)は一部仕様が異なります。この推論スクリプトでは両者の方式を選択できるようになっています。
383
+
384
+ ※SD3.5方式はcondとuncondに加えてslg outputを適用します(速度が低下します)。Wan2GP方式はcondとslg outputのみを使用します。
385
+
386
+ 以下の引数があります。
387
+
388
+ - `--slg_mode`:SLGのモードを指定します。`original`でSD 3.5の方式、`uncond`でWan2GPの方式です。デフォルトはNoneで、SLGを使用しません。
389
+ - `--slg_layers`:SLGでスキップするblock (layer)のインデクスをカンマ区切りで指定します。例:`--slg_layers 4,5,6`。デフォルトは空(スキップしない)です。このオプションを指定しないと`--slg_mode`は無視されます。
390
+ - `--slg_scale`:`original`のときのSLGのスケールを指定します。デフォルトは3.0です。
391
+ - `--slg_start`:推論ステップのSLG適用開始ステップを0.0から1.0の割合で指定します。デフォルトは0.0です(最初から適用)。
392
+ - `--slg_end`:推論ステップのSLG適用終了ステップを0.0から1.0の割合で指定します。デフォルトは0.3です(最初から30%まで適用)。
393
+
394
+ 適切な設定は不明ですが、`original`モードでスケールを3.0程度、開始割合を0.0、終了割合を0.5程度に設定し、4, 5, 6のlayerをスキップする設定から始めると良いかもしれません。
395
+ </details>
396
+
397
  ### I2V Inference / I2V推論
398
 
399
  The following is an example of I2V inference (input as a single line):
 
408
 
409
  Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
410
 
411
+ `--end_image_path` can be used to specify the end image. This option is experimental. When this option is specified, the saved video will be slightly longer than the specified number of frames and will have noise, so it is recommended to specify `--trim_tail_frames 3` to trim the tail frames.
412
+
413
+ You can also use the Fun Control model for I2V inference. Specify the control video with `--control_path`.
414
+
415
  Other options are same as T2V inference.
416
 
417
  <details>
418
  <summary>日本語</summary>
419
  `--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
420
 
421
+ `--end_image_path` で終了画像を指定できます。このオプションは実験的なものです。このオプションを指定すると、保存される動画が指定フレーム数よりもやや多くなり、かつノイズが乗るため、`--trim_tail_frames 3` などを指定して末尾のフレームをトリミングすることをお勧めします。
422
+
423
+ I2V推論でもFun Controlモデルが使用できます。`--control_path` で制御用の映像を指定します。
424
+
425
  その他のオプションはT2V推論と同じです。
426
  </details>
427
+
428
+ ### New Batch and Interactive Modes / 新しいバッチモードとインタラクティブモード
429
+
430
+ In addition to single video generation, Wan 2.1 now supports batch generation from file and interactive prompt input:
431
+
432
+ #### Batch Mode from File / ファイルからのバッチモード
433
+
434
+ Generate multiple videos from prompts stored in a text file:
435
+
436
+ ```bash
437
+ python wan_generate_video.py --from_file prompts.txt --task t2v-14B \
438
+ --dit path/to/model.safetensors --vae path/to/vae.safetensors \
439
+ --t5 path/to/t5_model.pth --save_path output_directory
440
+ ```
441
+
442
+ The prompts file format:
443
+ - One prompt per line
444
+ - Empty lines and lines starting with # are ignored (comments)
445
+ - Each line can include prompt-specific parameters using command-line style format:
446
+
447
+ ```
448
+ A beautiful sunset over mountains --w 832 --h 480 --f 81 --d 42 --s 20
449
+ A busy city street at night --w 480 --h 832 --g 7.5 --n low quality, blurry
450
+ ```
451
+
452
+ Supported inline parameters (if omitted, default values from the command line are used); a rough parsing sketch follows this list:
453
+ - `--w`: Width
454
+ - `--h`: Height
455
+ - `--f`: Frame count
456
+ - `--d`: Seed
457
+ - `--s`: Inference steps
458
+ - `--g` or `--l`: Guidance scale
459
+ - `--fs`: Flow shift
460
+ - `--i`: Image path (for I2V)
461
+ - `--cn`: Control path (for Fun Control)
462
+ - `--n`: Negative prompt
463
+
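+ Conceptually, each line is split into the prompt text and its overrides as in the simplified sketch below (the helper name is made up; the actual parser in `wan_generate_video.py` may handle edge cases differently):
+
+ ```python
+ def parse_prompt_line(line: str) -> tuple[str, dict[str, str]]:
+     # "prompt text --w 832 --h 480" -> ("prompt text", {"w": "832", "h": "480"})
+     parts = line.strip().split(" --")
+     prompt = parts[0].strip()
+     overrides = {}
+     for part in parts[1:]:
+         key, _, value = part.partition(" ")
+         overrides[key] = value.strip()  # values keep their spaces, e.g. {"n": "low quality, blurry"}
+     return prompt, overrides
+
+
+ print(parse_prompt_line("A beautiful sunset over mountains --w 832 --h 480 --d 42"))
+ ```
+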
464
+ In batch mode, models are loaded once and reused for all prompts, significantly improving overall generation time compared to multiple single runs.
465
+
466
+ #### Interactive Mode / インタラクティブモード
467
+
468
+ Interactive command-line interface for entering prompts:
469
+
470
+ ```bash
471
+ python wan_generate_video.py --interactive --task t2v-14B \
472
+ --dit path/to/model.safetensors --vae path/to/vae.safetensors \
473
+ --t5 path/to/t5_model.pth --save_path output_directory
474
+ ```
475
+
476
+ In interactive mode:
477
+ - Enter prompts directly at the command line
478
+ - Use the same inline parameter format as batch mode
479
+ - Use Ctrl+D (or Ctrl+Z on Windows) to exit
480
+ - Models remain loaded between generations for efficiency
481
+
482
+ <details>
483
+ <summary>日本語</summary>
484
+ 単一動画の生成に加えて、Wan 2.1は現在、ファイルからのバッチ生成とインタラクティブなプロンプト入力をサポートしています。
485
+
486
+ #### ファイルからのバッチモード
487
+
488
+ テキストファイルに保存されたプロンプトから複数の動画を生成します:
489
+
490
+ ```bash
491
+ python wan_generate_video.py --from_file prompts.txt --task t2v-14B \
492
+ --dit path/to/model.safetensors --vae path/to/vae.safetensors \
493
+ --t5 path/to/t5_model.pth --save_path output_directory
494
+ ```
495
+
496
+ プロンプトファイルの形式:
497
+ - 1行に1つのプロンプト
498
+ - 空行や#で始まる行は無視されます(コメント)
499
+ - 各行にはコマンドライン形式でプロンプト固有のパラメータを含めることができます:
500
+
501
+ サポートされているインラインパラメータ(省略した場合、コマンドラインのデフォルト値が使用されます)
502
+ - `--w`: 幅
503
+ - `--h`: 高さ
504
+ - `--f`: フレーム数
505
+ - `--d`: シード
506
+ - `--s`: 推論ステップ
507
+ - `--g` または `--l`: ガイダンススケール
508
+ - `--fs`: フローシフト
509
+ - `--i`: 画像パス(I2V用)
510
+ - `--cn`: コントロールパス(Fun Control用)
511
+ - `--n`: ネガティブプロンプト
512
+
513
+ バッチモードでは、モデルは一度だけロードされ、すべてのプロンプトで再利用されるため、複数回の単一実行と比較して全体的な生成時間が大幅に改善されます。
514
+
515
+ #### インタラクティブモード
516
+
517
+ プロンプトを入力するためのインタラクティブなコマンドラインインターフェース:
518
+
519
+ ```bash
520
+ python wan_generate_video.py --interactive --task t2v-14B \
521
+ --dit path/to/model.safetensors --vae path/to/vae.safetensors \
522
+ --t5 path/to/t5_model.pth --save_path output_directory
523
+ ```
524
+
525
+ インタラクティブモードでは:
526
+ - コマンドラインで直接プロンプトを入力
527
+ - バッチモードと同じインラインパラメータ形式を使用
528
+ - 終了するには Ctrl+D (Windowsでは Ctrl+Z) を使用
529
+ - 効率のため、モデルは生成間で読み込まれたままになります
530
+ </details>
531
+
fpack_cache_latents.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from typing import List, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from tqdm import tqdm
11
+ from transformers import SiglipImageProcessor, SiglipVisionModel
12
+ from PIL import Image
13
+
14
+ from dataset import config_utils
15
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
16
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK
17
+ from frame_pack import hunyuan
18
+ from frame_pack.framepack_utils import load_image_encoders, load_vae
19
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
20
+ from frame_pack.clip_vision import hf_clip_vision_encode
21
+ import cache_latents
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ def encode_and_save_batch(
28
+ vae: AutoencoderKLCausal3D,
29
+ feature_extractor: SiglipImageProcessor,
30
+ image_encoder: SiglipVisionModel,
31
+ batch: List[ItemInfo],
32
+ vanilla_sampling: bool = False,
33
+ one_frame: bool = False,
34
+ one_frame_no_2x: bool = False,
35
+ one_frame_no_4x: bool = False,
36
+ ):
37
+ """Encode a batch of original RGB videos and save FramePack section caches."""
38
+ if one_frame:
39
+ encode_and_save_batch_one_frame(
40
+ vae, feature_extractor, image_encoder, batch, vanilla_sampling, one_frame_no_2x, one_frame_no_4x
41
+ )
42
+ return
43
+
44
+ latent_window_size = batch[0].fp_latent_window_size # all items should have the same window size
45
+
46
+ # Stack batch into tensor (B,C,F,H,W) in RGB order
47
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
48
+ if len(contents.shape) == 4:
49
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
50
+
51
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
52
+ contents = contents.to(vae.device, dtype=vae.dtype)
53
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
54
+
55
+ height, width = contents.shape[3], contents.shape[4]
56
+ if height < 8 or width < 8:
57
+ item = batch[0] # other items should have the same size
58
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
59
+
60
+ # calculate latent frame count from original frame count (4n+1)
61
+ latent_f = (batch[0].frame_count - 1) // 4 + 1
62
+
63
+ # calculate the total number of sections (excluding the first frame, divided by window size)
64
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
65
+ if total_latent_sections < 1:
66
+ min_frames_needed = latent_window_size * 4 + 1
67
+ raise ValueError(
68
+ f"Not enough frames for FramePack: {batch[0].frame_count} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size+1} latent frames)"
69
+ )
70
+
71
+ # actual latent frame count (aligned to section boundaries)
72
+ latent_f_aligned = total_latent_sections * latent_window_size + 1 if not one_frame else 1
73
+
74
+ # actual video frame count
75
+ frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
76
+ if frame_count_aligned != batch[0].frame_count:
77
+ logger.info(
78
+ f"Frame count mismatch: required={frame_count_aligned} != actual={batch[0].frame_count}, trimming to {frame_count_aligned}"
79
+ )
80
+ contents = contents[:, :, :frame_count_aligned, :, :]
81
+
82
+ latent_f = latent_f_aligned # Update to the aligned value
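+ # Worked example (illustrative): frame_count=89, latent_window_size=9
+ #   latent_f = (89 - 1) // 4 + 1 = 23, total_latent_sections = floor((23 - 1) / 9) = 2
+ #   latent_f_aligned = 2 * 9 + 1 = 19, frame_count_aligned = (19 - 1) * 4 + 1 = 73 -> video trimmed to 73 frames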
83
+
84
+ # VAE encode (list of tensor -> stack)
85
+ latents = hunyuan.vae_encode(contents, vae) # include scaling factor
86
+ latents = latents.to("cpu") # (B, C, latent_f, H/8, W/8)
87
+
88
+ # Vision encoding per‑item (once)
89
+ images = np.stack([item.content[0] for item in batch], axis=0) # B, H, W, C
90
+
91
+ # encode image with image encoder
92
+ image_embeddings = []
93
+ with torch.no_grad():
94
+ for image in images:
95
+ image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
96
+ image_embeddings.append(image_encoder_output.last_hidden_state)
97
+ image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
98
+ image_embeddings = image_embeddings.to("cpu") # Save memory
99
+
100
+ if not vanilla_sampling:
101
+ # padding is reversed for inference (future to past)
102
+ latent_paddings = list(reversed(range(total_latent_sections)))
103
+ # Note: The padding trick for inference. See the paper for details.
104
+ if total_latent_sections > 4:
105
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
106
+
107
+ for b, item in enumerate(batch):
108
+ original_latent_cache_path = item.latent_cache_path
109
+ video_lat = latents[b : b + 1] # keep batch dim, 1, C, F, H, W
110
+
111
+ # emulate inference step (history latents)
112
+ # Note: In inference, history_latents stores *generated* future latents.
113
+ # Here, for caching, we just need its shape and type for clean_* tensors.
114
+ # The actual content doesn't matter much as clean_* will be overwritten.
115
+ history_latents = torch.zeros(
116
+ (1, video_lat.shape[1], 1 + 2 + 16, video_lat.shape[3], video_lat.shape[4]), dtype=video_lat.dtype
117
+ ) # C=16 for HY
118
+
119
+ latent_f_index = latent_f - latent_window_size # Start from the last section
120
+ section_index = total_latent_sections - 1
121
+
122
+ for latent_padding in latent_paddings:
123
+ is_last_section = section_index == 0 # the last section in inference order == the first section in time
124
+ latent_padding_size = latent_padding * latent_window_size
125
+ if is_last_section:
126
+ assert latent_f_index == 1, "Last section should be starting from frame 1"
127
+
128
+ # indices generation (same as inference)
129
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
130
+ (
131
+ clean_latent_indices_pre, # Index for start_latent
132
+ blank_indices, # Indices for padding (future context in inference)
133
+ latent_indices, # Indices for the target latents to predict
134
+ clean_latent_indices_post, # Index for the most recent history frame
135
+ clean_latent_2x_indices, # Indices for the next 2 history frames
136
+ clean_latent_4x_indices, # Indices for the next 16 history frames
137
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
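+ # Index layout example (illustrative) with latent_window_size=9 and latent_padding=2 (padding size 18):
+ #   [start(1) | blank(18) | target(9) | clean_post(1) | clean_2x(2) | clean_4x(16)] -> 47 positions in total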
138
+
139
+ # Indices for clean_latents (start + recent history)
140
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
141
+
142
+ # clean latents preparation (emulating inference)
143
+ clean_latents_pre = video_lat[:, :, 0:1, :, :] # Always the first frame (start_latent)
144
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
145
+ [1, 2, 16], dim=2
146
+ )
147
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Combine start frame + placeholder
148
+
149
+ # Target latents for this section (ground truth)
150
+ target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]
151
+
152
+ # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
153
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
154
+ save_latent_cache_framepack(
155
+ item_info=item,
156
+ latent=target_latents.squeeze(0), # Ground truth for this section
157
+ latent_indices=latent_indices.squeeze(0), # Indices for the ground truth section
158
+ clean_latents=clean_latents.squeeze(0), # Start frame + history placeholder
159
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for start frame + history placeholder
160
+ clean_latents_2x=clean_latents_2x.squeeze(0), # History placeholder
161
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for history placeholder
162
+ clean_latents_4x=clean_latents_4x.squeeze(0), # History placeholder
163
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for history placeholder
164
+ image_embeddings=image_embeddings[b],
165
+ )
166
+
167
+ if is_last_section: # If this was the first section generated in inference (time=0)
168
+ # History gets the start frame + the generated first section
169
+ generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
170
+ else:
171
+ # History gets the generated current section
172
+ generated_latents_for_history = target_latents # Use true latents as stand-in for generated
173
+
174
+ history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)
175
+
176
+ section_index -= 1
177
+ latent_f_index -= latent_window_size
178
+
179
+ else:
180
+ # Vanilla Sampling Logic
181
+ for b, item in enumerate(batch):
182
+ original_latent_cache_path = item.latent_cache_path
183
+ video_lat = latents[b : b + 1] # Keep batch dim: 1, C, F_aligned, H, W
184
+ img_emb = image_embeddings[b] # LEN, 1152
185
+
186
+ for section_index in range(total_latent_sections):
187
+ target_start_f = section_index * latent_window_size + 1
188
+ target_end_f = target_start_f + latent_window_size
189
+ target_latents = video_lat[:, :, target_start_f:target_end_f, :, :]
190
+ start_latent = video_lat[:, :, 0:1, :, :]
191
+
192
+ # Clean latents preparation (Vanilla)
193
+ clean_latents_total_count = 1 + 2 + 16
194
+ history_latents = torch.zeros(
195
+ size=(1, 16, clean_latents_total_count, video_lat.shape[-2], video_lat.shape[-1]),
196
+ device=video_lat.device,
197
+ dtype=video_lat.dtype,
198
+ )
199
+
200
+ history_start_f = 0
201
+ video_start_f = target_start_f - clean_latents_total_count
202
+ copy_count = clean_latents_total_count
203
+ if video_start_f < 0:
204
+ history_start_f = -video_start_f
205
+ copy_count = clean_latents_total_count - history_start_f
206
+ video_start_f = 0
207
+ if copy_count > 0:
208
+ history_latents[:, :, history_start_f:] = video_lat[:, :, video_start_f : video_start_f + copy_count, :, :]
209
+
210
+ # indices generation (Vanilla): copy from FramePack-F1
211
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
212
+ (
213
+ clean_latent_indices_start,
214
+ clean_latent_4x_indices,
215
+ clean_latent_2x_indices,
216
+ clean_latent_1x_indices,
217
+ latent_indices,
218
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
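+ # Index layout example (illustrative) with latent_window_size=9 (past-to-future order for F1 / vanilla sampling):
+ #   [start(1) | clean_4x(16) | clean_2x(2) | clean_1x(1) | target(9)] -> 29 positions in total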
219
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
220
+
221
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
222
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
223
+
224
+ # Save cache
225
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
226
+ save_latent_cache_framepack(
227
+ item_info=item,
228
+ latent=target_latents.squeeze(0),
229
+ latent_indices=latent_indices.squeeze(0), # Indices for target section i
230
+ clean_latents=clean_latents.squeeze(0), # Past clean frames
231
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for clean_latents_pre/post
232
+ clean_latents_2x=clean_latents_2x.squeeze(0), # Past clean frames (2x)
233
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for clean_latents_2x
234
+ clean_latents_4x=clean_latents_4x.squeeze(0), # Past clean frames (4x)
235
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for clean_latents_4x
236
+ image_embeddings=img_emb,
237
+ # Note: We don't explicitly save past_offset_indices,
238
+ # but its size influences the absolute values in other indices.
239
+ )
240
+
241
+
242
+ def encode_and_save_batch_one_frame(
243
+ vae: AutoencoderKLCausal3D,
244
+ feature_extractor: SiglipImageProcessor,
245
+ image_encoder: SiglipVisionModel,
246
+ batch: List[ItemInfo],
247
+ vanilla_sampling: bool = False,
248
+ one_frame_no_2x: bool = False,
249
+ one_frame_no_4x: bool = False,
250
+ ):
251
+ # item.content: target image (H, W, C)
252
+ # item.control_content: list of images (H, W, C)
253
+
254
+ # Stack batch into tensor (B,F,H,W,C) in RGB order. The numbers of control content for each item are the same.
255
+ contents = []
256
+ content_masks: list[list[Optional[torch.Tensor]]] = []
257
+ for item in batch:
258
+ item_contents = item.control_content + [item.content]
259
+
260
+ item_masks = []
261
+ for i, c in enumerate(item_contents):
262
+ if c.shape[-1] == 4: # RGBA
263
+ item_contents[i] = c[..., :3] # remove alpha channel from content
264
+
265
+ alpha = c[..., 3] # extract alpha channel
266
+ mask_image = Image.fromarray(alpha, mode="L")
267
+ width, height = mask_image.size
268
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
269
+ mask_image = np.array(mask_image) # PIL (mode "L") to numpy, HW
270
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HW
271
+ mask_image = mask_image.squeeze(-1) # no-op for HW input; drops a trailing singleton channel if present
272
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
273
+ mask_image = mask_image.to(torch.float32)
274
+ content_mask = mask_image
275
+ else:
276
+ content_mask = None
277
+
278
+ item_masks.append(content_mask)
279
+
280
+ item_contents = [torch.from_numpy(c) for c in item_contents]
281
+ contents.append(torch.stack(item_contents, dim=0)) # list of [F, H, W, C]
282
+ content_masks.append(item_masks)
283
+
284
+ contents = torch.stack(contents, dim=0) # B, F, H, W, C. F is control frames + target frame
285
+
286
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
287
+ contents = contents.to(vae.device, dtype=vae.dtype)
288
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
289
+
290
+ height, width = contents.shape[-2], contents.shape[-1]
291
+ if height < 8 or width < 8:
292
+ item = batch[0] # other items should have the same size
293
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
294
+
295
+ # VAE encode: we need to encode one frame at a time because VAE encoder has stride=4 for the time dimension except for the first frame.
296
+ latents = [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
297
+ latents = torch.cat(latents, dim=2) # B, C, F, H/8, W/8
298
+
299
+ # apply alphas to latents
300
+ for b, item in enumerate(batch):
301
+ for i, content_mask in enumerate(content_masks[b]):
302
+ if content_mask is not None:
303
+ # apply mask to the latents
304
+ # print(f"Applying content mask for item {item.item_key}, frame {i}")
305
+ latents[b : b + 1, :, i : i + 1] *= content_mask
306
+
307
+ # Vision encoding per‑item (once): use control content because it is the start image
308
+ images = [item.control_content[0] for item in batch] # list of [H, W, C]
309
+
310
+ # encode image with image encoder
311
+ image_embeddings = []
312
+ with torch.no_grad():
313
+ for image in images:
314
+ image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
315
+ image_embeddings.append(image_encoder_output.last_hidden_state)
316
+ image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
317
+ image_embeddings = image_embeddings.to("cpu") # Save memory
318
+
319
+ # save cache for each item in the batch
320
+ for b, item in enumerate(batch):
321
+ # indices generation (same as inference): each item may have different clean_latent_indices, so we generate them per item
322
+ clean_latent_indices = item.fp_1f_clean_indices # list of indices for clean latents
323
+ if clean_latent_indices is None or len(clean_latent_indices) == 0:
324
+ logger.warning(
325
+ f"Item {item.item_key} has no clean_latent_indices defined, using default indices for one frame training."
326
+ )
327
+ clean_latent_indices = [0]
328
+
329
+ if not item.fp_1f_no_post:
330
+ clean_latent_indices = clean_latent_indices + [1 + item.fp_latent_window_size]
331
+ clean_latent_indices = torch.Tensor(clean_latent_indices).long() # N
332
+
333
+ latent_index = torch.Tensor([item.fp_1f_target_index]).long() # 1
334
+
335
+ # the zero-valued clean_latents_2x/4x tensors do not need to be cached even if one_frame_no_2x or one_frame_no_4x is False
336
+ clean_latents_2x = None
337
+ clean_latents_4x = None
338
+
339
+ if one_frame_no_2x:
340
+ clean_latent_2x_indices = None
341
+ else:
342
+ index = 1 + item.fp_latent_window_size + 1
343
+ clean_latent_2x_indices = torch.arange(index, index + 2) # 2
344
+
345
+ if one_frame_no_4x:
346
+ clean_latent_4x_indices = None
347
+ else:
348
+ index = 1 + item.fp_latent_window_size + 1 + 2
349
+ clean_latent_4x_indices = torch.arange(index, index + 16) # 16
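+ # Index example (illustrative) with fp_latent_window_size=9, clean_latent_indices=[0] and fp_1f_no_post=False:
+ #   clean_latent_indices=[0, 10], clean_latent_2x_indices=[11, 12], clean_latent_4x_indices=[13..28], latent_index=[fp_1f_target_index]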
350
+
351
+ # clean latents preparation (emulating inference)
352
+ clean_latents = latents[b, :, :-1] # C, F, H, W
353
+ if not item.fp_1f_no_post:
354
+ # If zero post is enabled, we need to add a zero frame at the end
355
+ clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0) # C, F+1, H, W
356
+
357
+ # Target latents for this section (ground truth)
358
+ target_latents = latents[b, :, -1:] # C, 1, H, W
359
+
360
+ print(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
361
+ print(f" Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
362
+ print(f" Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
363
+ print(f" Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
364
+ print(
365
+ f" Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
366
+ f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
367
+ )
368
+ print(f" Image embeddings: {image_embeddings[b].shape}")
369
+
370
+ # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
371
+ save_latent_cache_framepack(
372
+ item_info=item,
373
+ latent=target_latents, # Ground truth for this section
374
+ latent_indices=latent_index, # Indices for the ground truth section
375
+ clean_latents=clean_latents, # Start frame + history placeholder
376
+ clean_latent_indices=clean_latent_indices, # Indices for start frame + history placeholder
377
+ clean_latents_2x=clean_latents_2x, # History placeholder
378
+ clean_latent_2x_indices=clean_latent_2x_indices, # Indices for history placeholder
379
+ clean_latents_4x=clean_latents_4x, # History placeholder
380
+ clean_latent_4x_indices=clean_latent_4x_indices, # Indices for history placeholder
381
+ image_embeddings=image_embeddings[b],
382
+ )
383
+
384
+
385
+ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
386
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
387
+ parser.add_argument(
388
+ "--f1",
389
+ action="store_true",
390
+ help="Generate cache for F1 model (vanilla (autoregressive) sampling) instead of Inverted anti-drifting (plain FramePack)",
391
+ )
392
+ parser.add_argument(
393
+ "--one_frame",
394
+ action="store_true",
395
+ help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
396
+ )
397
+ parser.add_argument(
398
+ "--one_frame_no_2x",
399
+ action="store_true",
400
+ help="Do not use clean_latents_2x and clean_latent_2x_indices for one frame training.",
401
+ )
402
+ parser.add_argument(
403
+ "--one_frame_no_4x",
404
+ action="store_true",
405
+ help="Do not use clean_latents_4x and clean_latent_4x_indices for one frame training.",
406
+ )
407
+ return parser
408
+
409
+
410
+ def main(args: argparse.Namespace):
411
+ device = args.device if hasattr(args, "device") and args.device else ("cuda" if torch.cuda.is_available() else "cpu")
412
+ device = torch.device(device)
413
+
414
+ # Load dataset config
415
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
416
+ logger.info(f"Load dataset config from {args.dataset_config}")
417
+ user_config = config_utils.load_user_config(args.dataset_config)
418
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
419
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
420
+
421
+ datasets = train_dataset_group.datasets
422
+
423
+ if args.debug_mode is not None:
424
+ cache_latents.show_datasets(
425
+ datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
426
+ )
427
+ return
428
+
429
+ assert args.vae is not None, "vae checkpoint is required"
430
+
431
+ logger.info(f"Loading VAE model from {args.vae}")
432
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
433
+ vae.to(device)
434
+
435
+ logger.info(f"Loading image encoder from {args.image_encoder}")
436
+ feature_extractor, image_encoder = load_image_encoders(args)
437
+ image_encoder.eval()
438
+ image_encoder.to(device)
439
+
440
+ logger.info(f"Cache generation mode: {'Vanilla Sampling' if args.f1 else 'Inference Emulation'}")
441
+
442
+ # encoding closure
443
+ def encode(batch: List[ItemInfo]):
444
+ encode_and_save_batch(
445
+ vae, feature_extractor, image_encoder, batch, args.f1, args.one_frame, args.one_frame_no_2x, args.one_frame_no_4x
446
+ )
447
+
448
+ # reuse core loop from cache_latents with no change
449
+ encode_datasets_framepack(datasets, encode, args)
450
+
451
+
452
+ def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
453
+ tokens = latent_cache_path.split("_")
454
+ tokens[-3] = f"{tokens[-3]}-{section_idx:04d}" # append section index to "frame_pos-count"
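+ # Illustrative example (basename and architecture tokens depend on the dataset):
+ #   "clip_00000-0089_0512x0768_framepack.safetensors", section_idx=2 -> "clip_00000-0089-0002_0512x0768_framepack.safetensors"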
455
+ return "_".join(tokens)
456
+
457
+
458
+ def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
459
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
460
+ for i, dataset in enumerate(datasets):
461
+ logger.info(f"Encoding dataset [{i}]")
462
+ all_latent_cache_paths = []
463
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
464
+ batch: list[ItemInfo] = batch # type: ignore
465
+
466
+ # latent_cache_path is "{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
467
+ # For video datasets, we expand it to "{basename}_{section_idx:04d}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
468
+ filtered_batch = []
469
+ for item in batch:
470
+ if item.frame_count is None:
471
+ # image dataset
472
+ all_latent_cache_paths.append(item.latent_cache_path)
473
+ all_existing = os.path.exists(item.latent_cache_path)
474
+ else:
475
+ latent_f = (item.frame_count - 1) // 4 + 1
476
+ num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size)) # min 1 section
477
+ all_existing = True
478
+ for sec in range(num_sections):
479
+ p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
480
+ all_latent_cache_paths.append(p)
481
+ all_existing = all_existing and os.path.exists(p)
482
+
483
+ if not all_existing: # if any section cache is missing
484
+ filtered_batch.append(item)
485
+
486
+ if args.skip_existing:
487
+ if len(filtered_batch) == 0: # all sections exist
488
+ logger.info(f"All sections exist for {batch[0].item_key}, skipping")
489
+ continue
490
+ batch = filtered_batch # update batch to only missing sections
491
+
492
+ bs = args.batch_size if args.batch_size is not None else len(batch)
493
+ for i in range(0, len(batch), bs):
494
+ encode(batch[i : i + bs])
495
+
496
+ # normalize paths
497
+ all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
498
+ all_latent_cache_paths = set(all_latent_cache_paths)
499
+
500
+ # remove old cache files not in the dataset
501
+ all_cache_files = dataset.get_all_latent_cache_files()
502
+ for cache_file in all_cache_files:
503
+ if os.path.normpath(cache_file) not in all_latent_cache_paths:
504
+ if args.keep_cache:
505
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
506
+ else:
507
+ os.remove(cache_file)
508
+ logger.info(f"Removed old cache file: {cache_file}")
509
+
510
+
511
+ if __name__ == "__main__":
512
+ parser = cache_latents.setup_parser_common()
513
+ parser = cache_latents.hv_setup_parser(parser) # VAE
514
+ parser = framepack_setup_parser(parser)
515
+
516
+ args = parser.parse_args()
517
+
518
+ if args.vae_dtype is not None:
519
+ raise ValueError("VAE dtype is not supported in FramePack")
520
+ # if args.batch_size != 1:
521
+ # args.batch_size = 1
522
+ # logger.info("Batch size is set to 1 for FramePack.")
523
+
524
+ main(args)
fpack_cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,110 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
12
+ import cache_text_encoder_outputs
13
+ from frame_pack import hunyuan
14
+ from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
15
+
16
+ import logging
17
+
18
+ from frame_pack.utils import crop_or_pad_yield_mask
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
+ def encode_and_save_batch(
25
+ tokenizer1: LlamaTokenizerFast,
26
+ text_encoder1: LlamaModel,
27
+ tokenizer2: CLIPTokenizer,
28
+ text_encoder2: CLIPTextModel,
29
+ batch: list[ItemInfo],
30
+ device: torch.device,
31
+ ):
32
+ prompts = [item.caption for item in batch]
33
+
34
+ # encode prompt
35
+ # FramePack's encode_prompt_conds only supports single prompt, so we need to encode each prompt separately
36
+ list_of_llama_vec = []
37
+ list_of_llama_attention_mask = []
38
+ list_of_clip_l_pooler = []
39
+ for prompt in prompts:
40
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
41
+ # llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompts, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
42
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
43
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
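+ # crop_or_pad_yield_mask fixes the sequence length to 512 tokens and returns the matching attention
+ # mask, so every cached text encoder output has the same shape regardless of prompt length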
44
+
45
+ list_of_llama_vec.append(llama_vec.squeeze(0))
46
+ list_of_llama_attention_mask.append(llama_attention_mask.squeeze(0))
47
+ list_of_clip_l_pooler.append(clip_l_pooler.squeeze(0))
48
+
49
+ # save prompt cache
50
+ for item, llama_vec, llama_attention_mask, clip_l_pooler in zip(
51
+ batch, list_of_llama_vec, list_of_llama_attention_mask, list_of_clip_l_pooler
52
+ ):
53
+ # save llama_vec and clip_l_pooler to cache
54
+ save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)
55
+
56
+
57
+ def main(args):
58
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
59
+ device = torch.device(device)
60
+
61
+ # Load dataset config
62
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
63
+ logger.info(f"Load dataset config from {args.dataset_config}")
64
+ user_config = config_utils.load_user_config(args.dataset_config)
65
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
66
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
67
+
68
+ datasets = train_dataset_group.datasets
69
+
70
+ # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
71
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
72
+
73
+ # load text encoder
74
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
75
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
76
+ text_encoder2.to(device)
77
+
78
+ # Encode with Text Encoders
79
+ logger.info("Encoding with Text Encoders")
80
+
81
+ def encode_for_text_encoder(batch: list[ItemInfo]):
82
+ encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)
83
+
84
+ cache_text_encoder_outputs.process_text_encoder_batches(
85
+ args.num_workers,
86
+ args.skip_existing,
87
+ args.batch_size,
88
+ datasets,
89
+ all_cache_files_for_dataset,
90
+ all_cache_paths_for_dataset,
91
+ encode_for_text_encoder,
92
+ )
93
+
94
+ # remove cache files not in dataset
95
+ cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
96
+
97
+
98
+ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
99
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
100
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
101
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
102
+ return parser
103
+
104
+
105
+ if __name__ == "__main__":
106
+ parser = cache_text_encoder_outputs.setup_parser_common()
107
+ parser = framepack_setup_parser(parser)
108
+
109
+ args = parser.parse_args()
110
+ main(args)
fpack_generate_video.py ADDED
@@ -0,0 +1,1832 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ import gc
4
+ import json
5
+ import random
6
+ import os
7
+ import re
8
+ import time
9
+ import math
10
+ import copy
11
+ from typing import Tuple, Optional, List, Union, Any, Dict
12
+
13
+ import torch
14
+ from safetensors.torch import load_file, save_file
15
+ from safetensors import safe_open
16
+ from PIL import Image
17
+ import cv2
18
+ import numpy as np
19
+ import torchvision.transforms.functional as TF
20
+ from transformers import LlamaModel
21
+ from tqdm import tqdm
22
+
23
+ from networks import lora_framepack
24
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
25
+ from frame_pack import hunyuan
26
+ from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
27
+ from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw
28
+ from frame_pack.bucket_tools import find_nearest_bucket
29
+ from frame_pack.clip_vision import hf_clip_vision_encode
30
+ from frame_pack.k_diffusion_hunyuan import sample_hunyuan
31
+ from dataset import image_video_dataset
32
+
33
+ try:
34
+ from lycoris.kohya import create_network_from_weights
35
+ except:
36
+ pass
37
+
38
+ from utils.device_utils import clean_memory_on_device
39
+ from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
40
+ from wan_generate_video import merge_lora_weights
41
+ from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders
42
+ from dataset.image_video_dataset import load_video
43
+
44
+ import logging
45
+
46
+ logger = logging.getLogger(__name__)
47
+ logging.basicConfig(level=logging.INFO)
48
+
49
+
50
+ class GenerationSettings:
51
+ def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
52
+ self.device = device
53
+ self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized
54
+
55
+
56
+ def parse_args() -> argparse.Namespace:
57
+ """parse command line arguments"""
58
+ parser = argparse.ArgumentParser(description="FramePack inference script")
59
+
60
+ # WAN arguments
61
+ # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
62
+ parser.add_argument(
63
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
64
+ )
65
+
66
+ parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
67
+ parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
68
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
69
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
70
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")
71
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method")
72
+
73
+ # LoRA
74
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
75
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
76
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
77
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
78
+ parser.add_argument(
79
+ "--save_merged_model",
80
+ type=str,
81
+ default=None,
82
+ help="Save merged model to path. If specified, no inference will be performed.",
83
+ )
84
+
85
+ # inference
86
+ parser.add_argument(
87
+ "--prompt",
88
+ type=str,
89
+ default=None,
90
+ help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
91
+ "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
92
+ )
93
+ parser.add_argument(
94
+ "--negative_prompt",
95
+ type=str,
96
+ default=None,
97
+ help="negative prompt for generation, default is empty string. should not change.",
98
+ )
99
+ parser.add_argument(
100
+ "--custom_system_prompt",
101
+ type=str,
102
+ default=None,
103
+ help="Custom system prompt for LLM. If specified, it will override the default system prompt. See hunyuan_model/text_encoder.py for the default system prompt.",
104
+ )
105
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
106
+ parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, default is 5.0 seconds")
107
+ parser.add_argument(
108
+ "--video_sections",
109
+ type=int,
110
+ default=None,
111
+ help="number of video sections, Default is None (auto calculate from video seconds)",
112
+ )
113
+ parser.add_argument(
114
+ "--one_frame_inference",
115
+ type=str,
116
+ default=None,
117
+ help="one frame inference, default is None, comma separated values from 'no_2x', 'no_4x', 'no_post', 'control_indices' and 'target_index'.",
118
+ )
119
+ parser.add_argument(
120
+ "--control_image_path", type=str, default=None, nargs="*", help="path to control (reference) image for one frame inference."
121
+ )
122
+ parser.add_argument(
123
+ "--control_image_mask_path",
124
+ type=str,
125
+ default=None,
126
+ nargs="*",
127
+ help="path to control (reference) image mask for one frame inference.",
128
+ )
129
+ parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
130
+ parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
131
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
132
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
133
+ # parser.add_argument(
134
+ # "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
135
+ # )
136
+ parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
137
+ parser.add_argument(
138
+ "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0"
139
+ )
140
+ parser.add_argument(
141
+ "--guidance_scale",
142
+ type=float,
143
+ default=1.0,
144
+ help="Guidance scale for classifier free guidance. Default is 1.0 (no guidance), should not change.",
145
+ )
146
+ parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")
147
+ # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
148
+ parser.add_argument(
149
+ "--image_path",
150
+ type=str,
151
+ default=None,
152
+ help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
153
+ )
154
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
155
+ parser.add_argument(
156
+ "--latent_paddings",
157
+ type=str,
158
+ default=None,
159
+ help="latent paddings for each section, comma separated values. default is None (FramePack default paddings)",
160
+ )
161
+ # parser.add_argument(
162
+ # "--control_path",
163
+ # type=str,
164
+ # default=None,
165
+ # help="path to control video for inference with controlnet. video file or directory with images",
166
+ # )
167
+ # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
168
+
169
+ # # Flow Matching
170
+ # parser.add_argument(
171
+ # "--flow_shift",
172
+ # type=float,
173
+ # default=None,
174
+ # help="Shift factor for flow matching schedulers. Default depends on task.",
175
+ # )
176
+
177
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
178
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
179
+ # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
180
+ parser.add_argument(
181
+ "--rope_scaling_factor", type=float, default=0.5, help="RoPE scaling factor for high resolution (H/W), default is 0.5"
182
+ )
183
+ parser.add_argument(
184
+ "--rope_scaling_timestep_threshold",
185
+ type=int,
186
+ default=None,
187
+ help="RoPE scaling timestep threshold, default is None (disable), if set, RoPE scaling will be applied only for timesteps >= threshold, around 800 is good starting point",
188
+ )
189
+
190
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
191
+ parser.add_argument(
192
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
193
+ )
194
+ parser.add_argument(
195
+ "--attn_mode",
196
+ type=str,
197
+ default="torch",
198
+ choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
199
+ help="attention mode",
200
+ )
201
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
202
+ parser.add_argument(
203
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
204
+ )
205
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
206
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
207
+ parser.add_argument(
208
+ "--output_type",
209
+ type=str,
210
+ default="video",
211
+ choices=["video", "images", "latent", "both", "latent_images"],
212
+ help="output type",
213
+ )
214
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
215
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
216
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
217
+ # parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
218
+ # parser.add_argument(
219
+ # "--compile_args",
220
+ # nargs=4,
221
+ # metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
222
+ # default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
223
+ # help="Torch.compile settings",
224
+ # )
225
+
226
+ # New arguments for batch and interactive modes
227
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
228
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
229
+
230
+ args = parser.parse_args()
231
+
232
+ # Validate arguments
233
+ if args.from_file and args.interactive:
234
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
235
+
236
+ if args.latent_path is None or len(args.latent_path) == 0:
237
+ if args.prompt is None and not args.from_file and not args.interactive:
238
+ raise ValueError("Either --prompt, --from_file or --interactive must be specified")
239
+
240
+ return args
241
+
242
+
243
+ def parse_prompt_line(line: str) -> Dict[str, Any]:
244
+ """Parse a prompt line into a dictionary of argument overrides
245
+
246
+ Args:
247
+ line: Prompt line with options
248
+
249
+ Returns:
250
+ Dict[str, Any]: Dictionary of argument overrides
251
+ """
252
+ # TODO common function with hv_train_network.line_to_prompt_dict
253
+ parts = line.split(" --")
254
+ prompt = parts[0].strip()
255
+
256
+ # Create dictionary of overrides
257
+ overrides = {"prompt": prompt}
258
+ # Initialize control_image_path and control_image_mask_path as a list to accommodate multiple paths
259
+ overrides["control_image_path"] = []
260
+ overrides["control_image_mask_path"] = []
261
+
262
+ for part in parts[1:]:
263
+ if not part.strip():
264
+ continue
265
+ option_parts = part.split(" ", 1)
266
+ option = option_parts[0].strip()
267
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
268
+
269
+ # Map options to argument names
270
+ if option == "w":
271
+ overrides["video_size_width"] = int(value)
272
+ elif option == "h":
273
+ overrides["video_size_height"] = int(value)
274
+ elif option == "f":
275
+ overrides["video_seconds"] = float(value)
276
+ elif option == "d":
277
+ overrides["seed"] = int(value)
278
+ elif option == "s":
279
+ overrides["infer_steps"] = int(value)
280
+ elif option == "g" or option == "l":
281
+ overrides["guidance_scale"] = float(value)
282
+ # elif option == "fs":
283
+ # overrides["flow_shift"] = float(value)
284
+ elif option == "i":
285
+ overrides["image_path"] = value
286
+ # elif option == "im":
287
+ # overrides["image_mask_path"] = value
288
+ # elif option == "cn":
289
+ # overrides["control_path"] = value
290
+ elif option == "n":
291
+ overrides["negative_prompt"] = value
292
+ elif option == "vs": # video_sections
293
+ overrides["video_sections"] = int(value)
294
+ elif option == "ei": # end_image_path
295
+ overrides["end_image_path"] = value
296
+ elif option == "ci": # control_image_path
297
+ overrides["control_image_path"].append(value)
298
+ elif option == "cim": # control_image_mask_path
299
+ overrides["control_image_mask_path"].append(value)
300
+ elif option == "of": # one_frame_inference
301
+ overrides["one_frame_inference"] = value
302
+
303
+ # If no control_image_path was provided, remove the empty list
304
+ if not overrides["control_image_path"]:
305
+ del overrides["control_image_path"]
306
+ if not overrides["control_image_mask_path"]:
307
+ del overrides["control_image_mask_path"]
308
+
309
+ return overrides
310
+
311
+
312
+ def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
313
+ """Apply overrides to args
314
+
315
+ Args:
316
+ args: Original arguments
317
+ overrides: Dictionary of overrides
318
+
319
+ Returns:
320
+ argparse.Namespace: New arguments with overrides applied
321
+ """
322
+ args_copy = copy.deepcopy(args)
323
+
324
+ for key, value in overrides.items():
325
+ if key == "video_size_width":
326
+ args_copy.video_size[1] = value
327
+ elif key == "video_size_height":
328
+ args_copy.video_size[0] = value
329
+ else:
330
+ setattr(args_copy, key, value)
331
+
332
+ return args_copy
333
+
334
+
335
+ def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]:
336
+ """Validate video size and length
337
+
338
+ Args:
339
+ args: command line arguments
340
+
341
+ Returns:
342
+ Tuple[int, int, float]: (height, width, video_seconds)
343
+ """
344
+ height = args.video_size[0]
345
+ width = args.video_size[1]
346
+
347
+ video_seconds = args.video_seconds
348
+ if args.video_sections is not None:
349
+ video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps
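+ # e.g. (illustrative) video_sections=3, latent_window_size=9, fps=30 -> (3 * 36 + 1) / 30 ≈ 3.63 seconds (109 frames)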
350
+
351
+ if height % 8 != 0 or width % 8 != 0:
352
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
353
+
354
+ return height, width, video_seconds
355
+
356
+
357
+ # region DiT model
358
+
359
+
360
+ def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
361
+ """load DiT model
362
+
363
+ Args:
364
+ args: command line arguments
365
+ device: device to use
366
+ dit_dtype: data type for the model
367
+ dit_weight_dtype: data type for the model weights. None for as-is
368
+
369
+ Returns:
370
+ HunyuanVideoTransformer3DModelPacked: DiT model
371
+ """
372
+ loading_device = "cpu"
373
+ if args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None:
374
+ loading_device = device
375
+
376
+ # do not fp8 optimize because we will merge LoRA weights
377
+ model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
378
+
379
+ # apply RoPE scaling factor
380
+ if args.rope_scaling_timestep_threshold is not None:
381
+ logger.info(
382
+ f"Applying RoPE scaling factor {args.rope_scaling_factor} for timesteps >= {args.rope_scaling_timestep_threshold}"
383
+ )
384
+ model.enable_rope_scaling(args.rope_scaling_timestep_threshold, args.rope_scaling_factor)
385
+ return model
386
+
387
+
388
+ def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
389
+ """optimize the model (FP8 conversion, device move etc.)
390
+
391
+ Args:
392
+ model: dit model
393
+ args: command line arguments
394
+ device: device to use
395
+ """
396
+ if args.fp8_scaled:
397
+ # load state dict as-is and optimize to fp8
398
+ state_dict = model.state_dict()
399
+
400
+ # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
401
+ move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
402
+ state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast)
403
+
404
+ info = model.load_state_dict(state_dict, strict=True, assign=True)
405
+ logger.info(f"Loaded FP8 optimized weights: {info}")
406
+
407
+ if args.blocks_to_swap == 0:
408
+ model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
409
+ else:
410
+ # simple cast to dit_dtype
411
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
412
+ target_device = None
413
+
414
+ if args.fp8:
415
+ target_dtype = torch.float8_e4m3fn
416
+
417
+ if args.blocks_to_swap == 0:
418
+ logger.info(f"Move model to device: {device}")
419
+ target_device = device
420
+
421
+ if target_device is not None and target_dtype is not None:
422
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
423
+
424
+ # if args.compile:
425
+ # compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
426
+ # logger.info(
427
+ # f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
428
+ # )
429
+ # torch._dynamo.config.cache_size_limit = 32
430
+ # for i in range(len(model.blocks)):
431
+ # model.blocks[i] = torch.compile(
432
+ # model.blocks[i],
433
+ # backend=compile_backend,
434
+ # mode=compile_mode,
435
+ # dynamic=compile_dynamic.lower() in "true",
436
+ # fullgraph=compile_fullgraph.lower() in "true",
437
+ # )
438
+
439
+ if args.blocks_to_swap > 0:
440
+ logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
441
+ model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
442
+ model.move_to_device_except_swap_blocks(device)
443
+ model.prepare_block_swap_before_forward()
444
+ else:
445
+ # make sure the model is on the right device
446
+ model.to(device)
447
+
448
+ model.eval().requires_grad_(False)
449
+ clean_memory_on_device(device)
450
+
451
+
452
+ # endregion
453
+
454
+
455
+ def decode_latent(
456
+ latent_window_size: int,
457
+ total_latent_sections: int,
458
+ bulk_decode: bool,
459
+ vae: AutoencoderKLCausal3D,
460
+ latent: torch.Tensor,
461
+ device: torch.device,
462
+ one_frame_inference_mode: bool = False,
463
+ ) -> torch.Tensor:
464
+ logger.info(f"Decoding video...")
465
+ if latent.ndim == 4:
466
+ latent = latent.unsqueeze(0) # add batch dimension
467
+
468
+ vae.to(device)
469
+ if not bulk_decode and not one_frame_inference_mode:
470
+ latent_window_size = latent_window_size # default is 9
471
+ # total_latent_sections = (args.video_seconds * 30) / (latent_window_size * 4)
472
+ # total_latent_sections = int(max(round(total_latent_sections), 1))
473
+ num_frames = latent_window_size * 4 - 3
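+ # e.g. (illustrative) latent_window_size=9 -> num_frames=33, i.e. 9 new latent frames per section;
+ # decoded sections overlap by the same 33 pixel frames and are blended with soft_append_bcthw below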
474
+
475
+ latents_to_decode = []
476
+ latent_frame_index = 0
477
+ for i in range(total_latent_sections - 1, -1, -1):
478
+ is_last_section = i == total_latent_sections - 1
479
+ generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
480
+ section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
481
+
482
+ section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
483
+ if section_latent.shape[2] > 0:
484
+ latents_to_decode.append(section_latent)
485
+
486
+ latent_frame_index += generated_latent_frames
487
+
488
+ latents_to_decode = latents_to_decode[::-1] # reverse the order of latents to decode
489
+
490
+ history_pixels = None
491
+ for latent in tqdm(latents_to_decode):
492
+ if history_pixels is None:
493
+ history_pixels = hunyuan.vae_decode(latent, vae).cpu()
494
+ else:
495
+ overlapped_frames = latent_window_size * 4 - 3
496
+ current_pixels = hunyuan.vae_decode(latent, vae).cpu()
497
+ history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
498
+ clean_memory_on_device(device)
499
+ else:
500
+ # bulk decode
501
+ logger.info(f"Bulk decoding or one frame inference")
502
+ if not one_frame_inference_mode:
503
+ history_pixels = hunyuan.vae_decode(latent, vae).cpu() # normal
504
+ else:
505
+ # one frame inference
506
+ history_pixels = [hunyuan.vae_decode(latent[:, :, i : i + 1, :, :], vae).cpu() for i in range(latent.shape[2])]
507
+ history_pixels = torch.cat(history_pixels, dim=2)
508
+
509
+ vae.to("cpu")
510
+
511
+ logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
512
+ return history_pixels[0] # remove batch dimension
513
+
514
+
515
+ def prepare_i2v_inputs(
516
+ args: argparse.Namespace,
517
+ device: torch.device,
518
+ vae: AutoencoderKLCausal3D,
519
+ shared_models: Optional[Dict] = None,
520
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
521
+ """Prepare inputs for I2V
522
+
523
+ Args:
524
+ args: command line arguments
525
+ config: model configuration
526
+ device: device to use
527
+ vae: VAE model, used for image encoding
528
+ shared_models: dictionary containing pre-loaded models
529
+
530
+ Returns:
531
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
532
+ (noise, context, context_null, y, (arg_c, arg_null))
533
+ """
534
+
535
+ height, width, video_seconds = check_inputs(args)
536
+
537
+ # define parsing function
538
+ def parse_section_strings(input_string: str) -> dict[int, str]:
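+ # Parsing example (illustrative): "0:a cat;;;1-2:a dog;;;-1:a bird"
+ #   -> {0: "a cat", 1: "a dog", 2: "a dog", -1: "a bird"}; a plain string without ";;;" becomes {0: input_string}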
539
+ section_strings = {}
540
+ if ";;;" in input_string:
541
+ split_section_strings = input_string.split(";;;")
542
+ for section_str in split_section_strings:
543
+ if ":" not in section_str:
544
+ start = end = 0
545
+ section_str = section_str.strip()
546
+ else:
547
+ index_str, section_str = section_str.split(":", 1)
548
+ index_str = index_str.strip()
549
+ section_str = section_str.strip()
550
+
551
+ m = re.match(r"^(-?\d+)(-\d+)?$", index_str)
552
+ if m:
553
+ start = int(m.group(1))
554
+ end = int(m.group(2)[1:]) if m.group(2) is not None else start
555
+ else:
556
+ start = end = 0
557
+ section_str = section_str.strip()
558
+ for i in range(start, end + 1):
559
+ section_strings[i] = section_str
560
+ else:
561
+ section_strings[0] = input_string
562
+
563
+ # assert 0 in section_prompts, "Section prompts must contain section 0"
564
+ if 0 not in section_strings:
565
+ # use smallest section index. prefer positive index over negative index
566
+ # if all section indices are negative, use the smallest negative index
567
+ indices = list(section_strings.keys())
568
+ if all(i < 0 for i in indices):
569
+ section_index = min(indices)
570
+ else:
571
+ section_index = min(i for i in indices if i >= 0)
572
+ section_strings[0] = section_strings[section_index]
573
+ return section_strings
574
+
575
+ # prepare image
576
+ def preprocess_image(image_path: str):
577
+ image = Image.open(image_path)
578
+ if image.mode == "RGBA":
579
+ alpha = image.split()[-1]
580
+ else:
581
+ alpha = None
582
+ image = image.convert("RGB")
583
+
584
+ image_np = np.array(image) # PIL to numpy, HWC
585
+
586
+ image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
587
+ image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
588
+ image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
589
+ return image_tensor, image_np, alpha
590
+
591
+ section_image_paths = parse_section_strings(args.image_path)
592
+
593
+ section_images = {}
594
+ for index, image_path in section_image_paths.items():
595
+ img_tensor, img_np, _ = preprocess_image(image_path)
596
+ section_images[index] = (img_tensor, img_np)
597
+
598
+ # check end image
599
+ if args.end_image_path is not None:
600
+ end_image_tensor, _, _ = preprocess_image(args.end_image_path)
601
+ else:
602
+ end_image_tensor = None
603
+
604
+ # check control (reference) images
605
+ if args.control_image_path is not None and len(args.control_image_path) > 0:
606
+ control_image_tensors = []
607
+ control_mask_images = []
608
+ for ctrl_image_path in args.control_image_path:
609
+ control_image_tensor, _, control_mask = preprocess_image(ctrl_image_path)
610
+ control_image_tensors.append(control_image_tensor)
611
+ control_mask_images.append(control_mask)
612
+ else:
613
+ control_image_tensors = None
614
+ control_mask_images = None
615
+
616
+ # configure negative prompt
617
+ n_prompt = args.negative_prompt if args.negative_prompt else ""
618
+
619
+ # parse section prompts
620
+ section_prompts = parse_section_strings(args.prompt)
621
+
622
+ # load text encoder
623
+ if shared_models is not None:
624
+ tokenizer1, text_encoder1 = shared_models["tokenizer1"], shared_models["text_encoder1"]
625
+ tokenizer2, text_encoder2 = shared_models["tokenizer2"], shared_models["text_encoder2"]
626
+ text_encoder1.to(device)
627
+ else:
628
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
629
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
630
+ text_encoder2.to(device)
631
+
632
+ logger.info(f"Encoding prompt")
633
+ llama_vecs = {}
634
+ llama_attention_masks = {}
635
+ clip_l_poolers = {}
636
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
637
+ for index, prompt in section_prompts.items():
638
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(
639
+ prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
640
+ )
641
+ llama_vec = llama_vec.cpu()
642
+ clip_l_pooler = clip_l_pooler.cpu()
643
+
644
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
645
+
646
+ llama_vecs[index] = llama_vec
647
+ llama_attention_masks[index] = llama_attention_mask
648
+ clip_l_poolers[index] = clip_l_pooler
649
+
650
+ if args.guidance_scale == 1.0:
651
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
652
+ else:
653
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
654
+ llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(
655
+ n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
656
+ )
657
+ llama_vec_n = llama_vec_n.cpu()
658
+ clip_l_pooler_n = clip_l_pooler_n.cpu()
659
+
660
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
661
+
662
+ # free text encoder and clean memory
663
+ if shared_models is not None: # if shared models are used, do not free them but move to CPU
664
+ text_encoder1.to("cpu")
665
+ text_encoder2.to("cpu")
666
+ del tokenizer1, text_encoder1, tokenizer2, text_encoder2 # do not free shared models
667
+ clean_memory_on_device(device)
668
+
669
+ # load image encoder
670
+ if shared_models is not None:
671
+ feature_extractor, image_encoder = shared_models["feature_extractor"], shared_models["image_encoder"]
672
+ else:
673
+ feature_extractor, image_encoder = load_image_encoders(args)
674
+ image_encoder.to(device)
675
+
676
+ # encode image with image encoder
677
+
678
+ section_image_encoder_last_hidden_states = {}
679
+ for index, (img_tensor, img_np) in section_images.items():
680
+ with torch.no_grad():
681
+ image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
682
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu()
683
+ section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state
684
+
685
+ # free image encoder and clean memory
686
+ if shared_models is not None:
687
+ image_encoder.to("cpu")
688
+ del image_encoder, feature_extractor
689
+ clean_memory_on_device(device)
690
+
691
+ # VAE encoding
692
+ logger.info(f"Encoding image to latent space")
693
+ vae.to(device)
694
+
695
+ section_start_latents = {}
696
+ for index, (img_tensor, img_np) in section_images.items():
697
+ start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
698
+ section_start_latents[index] = start_latent
699
+
700
+ end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
701
+
702
+ control_latents = None
703
+ if control_image_tensors is not None:
704
+ control_latents = []
705
+ for ctrl_image_tensor in control_image_tensors:
706
+ control_latent = hunyuan.vae_encode(ctrl_image_tensor, vae).cpu()
707
+ control_latents.append(control_latent)
708
+
709
+ vae.to("cpu") # move VAE to CPU to save memory
710
+ clean_memory_on_device(device)
711
+
712
+ # prepare model input arguments
713
+ arg_c = {}
714
+ arg_null = {}
715
+ for index in llama_vecs.keys():
716
+ llama_vec = llama_vecs[index]
717
+ llama_attention_mask = llama_attention_masks[index]
718
+ clip_l_pooler = clip_l_poolers[index]
719
+ arg_c_i = {
720
+ "llama_vec": llama_vec,
721
+ "llama_attention_mask": llama_attention_mask,
722
+ "clip_l_pooler": clip_l_pooler,
723
+ "prompt": section_prompts[index], # for debugging
724
+ }
725
+ arg_c[index] = arg_c_i
726
+
727
+ arg_null = {
728
+ "llama_vec": llama_vec_n,
729
+ "llama_attention_mask": llama_attention_mask_n,
730
+ "clip_l_pooler": clip_l_pooler_n,
731
+ }
732
+
733
+ arg_c_img = {}
734
+ for index in section_images.keys():
735
+ image_encoder_last_hidden_state = section_image_encoder_last_hidden_states[index]
736
+ start_latent = section_start_latents[index]
737
+ arg_c_img_i = {
738
+ "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
739
+ "start_latent": start_latent,
740
+ "image_path": section_image_paths[index],
741
+ }
742
+ arg_c_img[index] = arg_c_img_i
743
+
744
+ return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, control_latents, control_mask_images
745
+
746
+
747
+ # def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
748
+ # """setup scheduler for sampling
749
+
750
+ # Args:
751
+ # args: command line arguments
752
+ # config: model configuration
753
+ # device: device to use
754
+
755
+ # Returns:
756
+ # Tuple[Any, torch.Tensor]: (scheduler, timesteps)
757
+ # """
758
+ # if args.sample_solver == "unipc":
759
+ # scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
760
+ # scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
761
+ # timesteps = scheduler.timesteps
762
+ # elif args.sample_solver == "dpm++":
763
+ # scheduler = FlowDPMSolverMultistepScheduler(
764
+ # num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
765
+ # )
766
+ # sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
767
+ # timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
768
+ # elif args.sample_solver == "vanilla":
769
+ # scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
770
+ # scheduler.set_timesteps(args.infer_steps, device=device)
771
+ # timesteps = scheduler.timesteps
772
+
773
+ # # FlowMatchDiscreteScheduler does not support generator argument in step method
774
+ # org_step = scheduler.step
775
+
776
+ # def step_wrapper(
777
+ # model_output: torch.Tensor,
778
+ # timestep: Union[int, torch.Tensor],
779
+ # sample: torch.Tensor,
780
+ # return_dict: bool = True,
781
+ # generator=None,
782
+ # ):
783
+ # return org_step(model_output, timestep, sample, return_dict=return_dict)
784
+
785
+ # scheduler.step = step_wrapper
786
+ # else:
787
+ # raise NotImplementedError("Unsupported solver.")
788
+
789
+ # return scheduler, timesteps
790
+
791
+
792
+ def convert_lora_for_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
793
+ # Check the format of the LoRA file
794
+ keys = list(lora_sd.keys())
795
+ if keys[0].startswith("lora_unet_"):
796
+ # logging.info(f"Musubi Tuner LoRA detected")
797
+ pass
798
+
799
+ else:
800
+ transformer_prefixes = ["diffusion_model", "transformer"] # to ignore Text Encoder modules
801
+ lora_suffix = None
802
+ prefix = None
803
+ for key in keys:
804
+ if lora_suffix is None and "lora_A" in key:
805
+ lora_suffix = "lora_A"
806
+ if prefix is None:
807
+ pfx = key.split(".")[0]
808
+ if pfx in transformer_prefixes:
809
+ prefix = pfx
810
+ if lora_suffix is not None and prefix is not None:
811
+ break
812
+
813
+ if lora_suffix == "lora_A" and prefix is not None:
814
+ logging.info(f"Diffusion-pipe (?) LoRA detected, converting to the default LoRA format")
815
+ lora_sd = convert_lora_from_diffusion_pipe_or_something(lora_sd, "lora_unet_")
816
+
817
+ else:
818
+ logging.info(f"LoRA file format not recognized. Using it as-is.")
819
+
820
+ # Check LoRA is for FramePack or for HunyuanVideo
821
+ is_hunyuan = False
822
+ for key in lora_sd.keys():
823
+ if "double_blocks" in key or "single_blocks" in key:
824
+ is_hunyuan = True
825
+ break
826
+ if is_hunyuan:
827
+ logging.info("HunyuanVideo LoRA detected, converting to FramePack format")
828
+ lora_sd = convert_hunyuan_to_framepack(lora_sd)
829
+
830
+ return lora_sd
831
+
832
+
833
+ def convert_lora_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
834
+ """
835
+ Convert LoRA weights from the Diffusers / diffusion-pipe format to the default (Musubi Tuner) format.
836
+ Copied from the Musubi Tuner repo.
837
+ """
838
+ # convert from diffusers(?) to default LoRA
839
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
840
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
841
+
842
+ # note: Diffusers has no alpha, so alpha is set to rank
843
+ new_weights_sd = {}
844
+ lora_dims = {}
845
+ for key, weight in lora_sd.items():
846
+ diffusers_prefix, key_body = key.split(".", 1)
847
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
848
+ print(f"unexpected key: {key} in diffusers format")
849
+ continue
850
+
851
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
852
+ new_weights_sd[new_key] = weight
853
+
854
+ lora_name = new_key.split(".")[0] # before first dot
855
+ if lora_name not in lora_dims and "lora_down" in new_key:
856
+ lora_dims[lora_name] = weight.shape[0]
857
+
858
+ # add alpha with rank
859
+ for lora_name, dim in lora_dims.items():
860
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
861
+
862
+ return new_weights_sd
863
+
864
+
865
+ def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
866
+ """
867
+ Convert HunyuanVideo LoRA weights to FramePack format.
868
+ """
869
+ new_lora_sd = {}
870
+ for key, weight in lora_sd.items():
871
+ if "double_blocks" in key:
872
+ key = key.replace("double_blocks", "transformer_blocks")
873
+ key = key.replace("img_mod_linear", "norm1_linear")
874
+ key = key.replace("img_attn_qkv", "attn_to_QKV") # split later
875
+ key = key.replace("img_attn_proj", "attn_to_out_0")
876
+ key = key.replace("img_mlp_fc1", "ff_net_0_proj")
877
+ key = key.replace("img_mlp_fc2", "ff_net_2")
878
+ key = key.replace("txt_mod_linear", "norm1_context_linear")
879
+ key = key.replace("txt_attn_qkv", "attn_add_QKV_proj") # split later
880
+ key = key.replace("txt_attn_proj", "attn_to_add_out")
881
+ key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj")
882
+ key = key.replace("txt_mlp_fc2", "ff_context_net_2")
883
+ elif "single_blocks" in key:
884
+ key = key.replace("single_blocks", "single_transformer_blocks")
885
+ key = key.replace("linear1", "attn_to_QKVM") # split later
886
+ key = key.replace("linear2", "proj_out")
887
+ key = key.replace("modulation_linear", "norm_linear")
888
+ else:
889
+ print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
890
+ continue
891
+
892
+ if "QKVM" in key:
893
+ # split QKVM into Q, K, V, M
894
+ key_q = key.replace("QKVM", "q")
895
+ key_k = key.replace("QKVM", "k")
896
+ key_v = key.replace("QKVM", "v")
897
+ key_m = key.replace("attn_to_QKVM", "proj_mlp")
898
+ if "_down" in key or "alpha" in key:
899
+ # copy QKVM weight or alpha to Q, K, V, M
900
+ assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
901
+ new_lora_sd[key_q] = weight
902
+ new_lora_sd[key_k] = weight
903
+ new_lora_sd[key_v] = weight
904
+ new_lora_sd[key_m] = weight
905
+ elif "_up" in key:
906
+ # split QKVM weight into Q, K, V, M
907
+ assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
908
+ new_lora_sd[key_q] = weight[:3072]
909
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
910
+ new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3]
911
+ new_lora_sd[key_m] = weight[3072 * 3 :] # 21504 - 3072 * 3 = 12288
912
+ else:
913
+ print(f"Unsupported module name: {key}")
914
+ continue
915
+ elif "QKV" in key:
916
+ # split QKV into Q, K, V
917
+ key_q = key.replace("QKV", "q")
918
+ key_k = key.replace("QKV", "k")
919
+ key_v = key.replace("QKV", "v")
920
+ if "_down" in key or "alpha" in key:
921
+ # copy QKV weight or alpha to Q, K, V
922
+ assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
923
+ new_lora_sd[key_q] = weight
924
+ new_lora_sd[key_k] = weight
925
+ new_lora_sd[key_v] = weight
926
+ elif "_up" in key:
927
+ # split QKV weight into Q, K, V
928
+ assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
929
+ new_lora_sd[key_q] = weight[:3072]
930
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
931
+ new_lora_sd[key_v] = weight[3072 * 2 :]
932
+ else:
933
+ print(f"Unsupported module name: {key}")
934
+ continue
935
+ else:
936
+ # no split needed
937
+ new_lora_sd[key] = weight
938
+
939
+ return new_lora_sd
940
+
941
+
942
+ def generate(
943
+ args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None
944
+ ) -> tuple[AutoencoderKLCausal3D, torch.Tensor]:
945
+ """main function for generation
946
+
947
+ Args:
948
+ args: command line arguments
949
+ shared_models: dictionary containing pre-loaded models
950
+
951
+ Returns:
952
+ tuple: (AutoencoderKLCausal3D model (vae), torch.Tensor generated latent)
953
+ """
954
+ device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
955
+
956
+ # prepare seed
957
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
958
+ args.seed = seed # set seed to args for saving
959
+
960
+ # Check if we have shared models
961
+ if shared_models is not None:
962
+ # Use shared models and encoded data
963
+ vae = shared_models.get("vae")
964
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
965
+ prepare_i2v_inputs(args, device, vae, shared_models)
966
+ )
967
+ else:
968
+ # prepare inputs without shared models
969
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
970
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
971
+ prepare_i2v_inputs(args, device, vae)
972
+ )
973
+
974
+ if shared_models is None or "model" not in shared_models:
975
+ # load DiT model
976
+ model = load_dit_model(args, device)
977
+
978
+ # merge LoRA weights
979
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
980
+ # ugly hack to common merge_lora_weights function
981
+ merge_lora_weights(lora_framepack, model, args, device, convert_lora_for_framepack)
982
+
983
+ # if we only want to save the model, we can skip the rest
984
+ if args.save_merged_model:
985
+ return None, None
986
+
987
+ # optimize model: fp8 conversion, block swap etc.
988
+ optimize_model(model, args, device)
989
+
990
+ if shared_models is not None:
991
+ shared_models["model"] = model
992
+ else:
993
+ # use shared model
994
+ model: HunyuanVideoTransformer3DModelPacked = shared_models["model"]
995
+ model.move_to_device_except_swap_blocks(device)
996
+ model.prepare_block_swap_before_forward()
997
+
998
+ # sampling
999
+ latent_window_size = args.latent_window_size # default is 9
1000
+ # ex: (5s * 30fps) / (9 * 4) = 4.16 -> 4 sections, 60s -> 1800 / 36 = 50 sections
1001
+ total_latent_sections = (video_seconds * 30) / (latent_window_size * 4)
1002
+ total_latent_sections = int(max(round(total_latent_sections), 1))
1003
+
1004
+ # set random generator
1005
+ seed_g = torch.Generator(device="cpu")
1006
+ seed_g.manual_seed(seed)
1007
+ num_frames = latent_window_size * 4 - 3
1008
+
1009
+ logger.info(
1010
+ f"Video size: {height}x{width}@{video_seconds} (HxW@seconds), fps: {args.fps}, num sections: {total_latent_sections}, "
1011
+ f"infer_steps: {args.infer_steps}, frames per generation: {num_frames}"
1012
+ )
1013
+
1014
+ # video generation ######
1015
+ f1_mode = args.f1
1016
+ one_frame_inference = None
1017
+ if args.one_frame_inference is not None:
1018
+ one_frame_inference = set()
1019
+ for mode in args.one_frame_inference.split(","):
1020
+ one_frame_inference.add(mode.strip())
1021
+
1022
+ if one_frame_inference is not None:
1023
+ real_history_latents = generate_with_one_frame_inference(
1024
+ args,
1025
+ model,
1026
+ context,
1027
+ context_null,
1028
+ context_img,
1029
+ control_latents,
1030
+ control_mask_images,
1031
+ latent_window_size,
1032
+ height,
1033
+ width,
1034
+ device,
1035
+ seed_g,
1036
+ one_frame_inference,
1037
+ )
1038
+ else:
1039
+ # prepare history latents
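+ # history layout along the frame axis: 1 frame for clean_latents_post, 2 for clean_latents_2x, 16 for clean_latents_4x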
1040
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
1041
+ if end_latent is not None and not f1_mode:
1042
+ logger.info(f"Use end image(s): {args.end_image_path}")
1043
+ history_latents[:, :, :1] = end_latent.to(history_latents)
1044
+
1045
+ # prepare clean latents and indices
1046
+ if not f1_mode:
1047
+ # Inverted Anti-drifting
1048
+ total_generated_latent_frames = 0
1049
+ latent_paddings = list(reversed(range(total_latent_sections)))
1050
+
1051
+ if total_latent_sections > 4 and one_frame_inference is None:
1052
+ # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
1053
+ # items looks better than expanding it when total_latent_sections > 4
1054
+ # One can try to remove below trick and just
1055
+ # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
1056
+ # 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
1057
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
1058
+
1059
+ if args.latent_paddings is not None:
1060
+ # parse user defined latent paddings
1061
+ user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
1062
+ if len(user_latent_paddings) < total_latent_sections:
1063
+ print(
1064
+ f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
1065
+ )
1066
+ print(f"Use default paddings instead for unspecified sections.")
1067
+ latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
1068
+ elif len(user_latent_paddings) > total_latent_sections:
1069
+ print(
1070
+ f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
1071
+ )
1072
+ print(f"Use only first {total_latent_sections} paddings instead.")
1073
+ latent_paddings = user_latent_paddings[:total_latent_sections]
1074
+ else:
1075
+ latent_paddings = user_latent_paddings
1076
+ else:
1077
+ start_latent = context_img[0]["start_latent"]
1078
+ history_latents = torch.cat([history_latents, start_latent], dim=2)
1079
+ total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
1080
+ latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
1081
+
1082
+ latent_paddings = list(latent_paddings) # make sure it's a list
1083
+ for loop_index in range(total_latent_sections):
1084
+ latent_padding = latent_paddings[loop_index]
1085
+
1086
+ if not f1_mode:
1087
+ # Inverted Anti-drifting
1088
+ section_index_reverse = loop_index # 0, 1, 2, 3
1089
+ section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
1090
+ section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
1091
+
1092
+ is_last_section = section_index == 0
1093
+ is_first_section = section_index_reverse == 0
1094
+ latent_padding_size = latent_padding * latent_window_size
1095
+
1096
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
1097
+ else:
1098
+ section_index = loop_index # 0, 1, 2, 3
1099
+ section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
1100
+ is_last_section = loop_index == total_latent_sections - 1
1101
+ is_first_section = loop_index == 0
1102
+ latent_padding_size = 0 # dummy padding for F1 mode
1103
+
1104
+ # select start latent
1105
+ if section_index_from_last in context_img:
1106
+ image_index = section_index_from_last
1107
+ elif section_index in context_img:
1108
+ image_index = section_index
1109
+ else:
1110
+ image_index = 0
1111
+
1112
+ start_latent = context_img[image_index]["start_latent"]
1113
+ image_path = context_img[image_index]["image_path"]
1114
+ if image_index != 0: # use section image other than section 0
1115
+ logger.info(
1116
+ f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}"
1117
+ )
1118
+
1119
+ if not f1_mode:
1120
+ # Inverted Anti-drifting
1121
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
1122
+ (
1123
+ clean_latent_indices_pre,
1124
+ blank_indices,
1125
+ latent_indices,
1126
+ clean_latent_indices_post,
1127
+ clean_latent_2x_indices,
1128
+ clean_latent_4x_indices,
1129
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
1130
+
1131
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
1132
+
1133
+ clean_latents_pre = start_latent.to(history_latents)
1134
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
1135
+ [1, 2, 16], dim=2
1136
+ )
1137
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
1138
+
1139
+ else:
1140
+ # F1 mode
1141
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
1142
+ (
1143
+ clean_latent_indices_start,
1144
+ clean_latent_4x_indices,
1145
+ clean_latent_2x_indices,
1146
+ clean_latent_1x_indices,
1147
+ latent_indices,
1148
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
1149
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
1150
+
1151
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
1152
+ [16, 2, 1], dim=2
1153
+ )
1154
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
1155
+
1156
+ # if use_teacache:
1157
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
1158
+ # else:
1159
+ # transformer.initialize_teacache(enable_teacache=False)
1160
+
1161
+ # prepare conditioning inputs
1162
+ if section_index_from_last in context:
1163
+ prompt_index = section_index_from_last
1164
+ elif section_index in context:
1165
+ prompt_index = section_index
1166
+ else:
1167
+ prompt_index = 0
1168
+
1169
+ context_for_index = context[prompt_index]
1170
+ # if args.section_prompts is not None:
1171
+ logger.info(f"Section {section_index}: {context_for_index['prompt']}")
1172
+
1173
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
1174
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
1175
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1176
+
1177
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
1178
+ device, dtype=torch.bfloat16
1179
+ )
1180
+
1181
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
1182
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
1183
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1184
+
1185
+ generated_latents = sample_hunyuan(
1186
+ transformer=model,
1187
+ sampler=args.sample_solver,
1188
+ width=width,
1189
+ height=height,
1190
+ frames=num_frames,
1191
+ real_guidance_scale=args.guidance_scale,
1192
+ distilled_guidance_scale=args.embedded_cfg_scale,
1193
+ guidance_rescale=args.guidance_rescale,
1194
+ # shift=3.0,
1195
+ num_inference_steps=args.infer_steps,
1196
+ generator=seed_g,
1197
+ prompt_embeds=llama_vec,
1198
+ prompt_embeds_mask=llama_attention_mask,
1199
+ prompt_poolers=clip_l_pooler,
1200
+ negative_prompt_embeds=llama_vec_n,
1201
+ negative_prompt_embeds_mask=llama_attention_mask_n,
1202
+ negative_prompt_poolers=clip_l_pooler_n,
1203
+ device=device,
1204
+ dtype=torch.bfloat16,
1205
+ image_embeddings=image_encoder_last_hidden_state,
1206
+ latent_indices=latent_indices,
1207
+ clean_latents=clean_latents,
1208
+ clean_latent_indices=clean_latent_indices,
1209
+ clean_latents_2x=clean_latents_2x,
1210
+ clean_latent_2x_indices=clean_latent_2x_indices,
1211
+ clean_latents_4x=clean_latents_4x,
1212
+ clean_latent_4x_indices=clean_latent_4x_indices,
1213
+ )
1214
+
1215
+ # concatenate generated latents
1216
+ total_generated_latent_frames += int(generated_latents.shape[2])
1217
+ if not f1_mode:
1218
+ # Inverted Anti-drifting: prepend generated latents to history latents
1219
+ if is_last_section:
1220
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
1221
+ total_generated_latent_frames += 1
1222
+
1223
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
1224
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
1225
+ else:
1226
+ # F1 mode: append generated latents to history latents
1227
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
1228
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
1229
+
1230
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
1231
+
1232
+ # # TODO support saving intermediate video
1233
+ # clean_memory_on_device(device)
1234
+ # vae.to(device)
1235
+ # if history_pixels is None:
1236
+ # history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
1237
+ # else:
1238
+ # section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
1239
+ # overlapped_frames = latent_window_size * 4 - 3
1240
+ # current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
1241
+ # history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
1242
+ # vae.to("cpu")
1243
+ # # if not is_last_section:
1244
+ # # # save intermediate video
1245
+ # # save_video(history_pixels[0], args, total_generated_latent_frames)
1246
+ # print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
1247
+
1248
+ # Only clean up shared models if they were created within this function
1249
+ if shared_models is None:
1250
+ del model # free memory
1251
+ synchronize_device(device)
1252
+ else:
1253
+ # move model to CPU to save memory
1254
+ model.to("cpu")
1255
+
1256
+ # wait for 5 seconds until block swap is done
1257
+ if args.blocks_to_swap > 0:
1258
+ logger.info("Waiting for 5 seconds to finish block swap")
1259
+ time.sleep(5)
1260
+
1261
+ gc.collect()
1262
+ clean_memory_on_device(device)
1263
+
1264
+ return vae, real_history_latents
1265
+
1266
+
1267
+ def generate_with_one_frame_inference(
1268
+ args: argparse.Namespace,
1269
+ model: HunyuanVideoTransformer3DModelPacked,
1270
+ context: Dict[int, Dict[str, torch.Tensor]],
1271
+ context_null: Dict[str, torch.Tensor],
1272
+ context_img: Dict[int, Dict[str, torch.Tensor]],
1273
+ control_latents: Optional[List[torch.Tensor]],
1274
+ control_mask_images: Optional[List[Optional[Image.Image]]],
1275
+ latent_window_size: int,
1276
+ height: int,
1277
+ width: int,
1278
+ device: torch.device,
1279
+ seed_g: torch.Generator,
1280
+ one_frame_inference: set[str],
1281
+ ) -> torch.Tensor:
1282
+ # one frame inference
1283
+ sample_num_frames = 1
1284
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
1285
+ latent_indices[:, 0] = latent_window_size # last of latent_window
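+ # default index layout used below: 0 = control/start latents, 1..latent_window_size = generation window
+ # (target defaults to the last slot), 1 + latent_window_size = clean latents post, then 2x (2) and 4x (16) slots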
1286
+
1287
+ def get_latent_mask(mask_image: Image.Image) -> torch.Tensor:
1288
+ if mask_image.mode != "L":
1289
+ mask_image = mask_image.convert("L")
1290
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
1291
+ mask_image = np.array(mask_image) # PIL (mode L) to numpy, HW
1293
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HW
1294
+ mask_image = mask_image.squeeze(-1) # no-op for single-channel (HW) input
1294
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
1295
+ mask_image = mask_image.to(torch.float32)
1296
+ return mask_image
1297
+
1298
+ if control_latents is None or len(control_latents) == 0:
1299
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
1300
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
1301
+
1302
+ if "no_post" not in one_frame_inference:
1303
+ # add zero latents as clean latents post
1304
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
1305
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
1306
+
1307
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
1308
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
1309
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
1310
+ if "no_post" not in one_frame_inference:
1311
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
1312
+
1313
+ for i in range(len(control_latents)):
1314
+ mask_image = None
1315
+ if args.control_image_mask_path is not None and i < len(args.control_image_mask_path):
1316
+ mask_image = get_latent_mask(Image.open(args.control_image_mask_path[i]))
1317
+ logger.info(
1318
+ f"Apply mask for clean latents 1x for {i + 1}: {args.control_image_mask_path[i]}, shape: {mask_image.shape}"
1319
+ )
1320
+ elif control_mask_images is not None and i < len(control_mask_images) and control_mask_images[i] is not None:
1321
+ mask_image = get_latent_mask(control_mask_images[i])
1322
+ logger.info(f"Apply mask for clean latents 1x for {i + 1} with alpha channel: {mask_image.shape}")
1323
+ if mask_image is not None:
1324
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * mask_image
1325
+
1326
+ for one_frame_param in one_frame_inference:
1327
+ if one_frame_param.startswith("target_index="):
1328
+ target_index = int(one_frame_param.split("=")[1])
1329
+ latent_indices[:, 0] = target_index
1330
+ logger.info(f"Set index for target: {target_index}")
1331
+ elif one_frame_param.startswith("control_index="):
1332
+ control_indices = one_frame_param.split("=")[1].split(";")
1333
+ i = 0
1334
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
1335
+ control_index = int(control_indices[i])
1336
+ clean_latent_indices[:, i] = control_index
1337
+ i += 1
1338
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
1339
+
1340
+ # "default" option does nothing, so we can skip it
1341
+ if "default" in one_frame_inference:
1342
+ pass
1343
+
1344
+ if "no_2x" in one_frame_inference:
1345
+ clean_latents_2x = None
1346
+ clean_latent_2x_indices = None
1347
+ logger.info(f"No clean_latents_2x")
1348
+ else:
1349
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
1350
+ index = 1 + latent_window_size + 1
1351
+ clean_latent_2x_indices = torch.arange(index, index + 2).unsqueeze(0) # 2
1352
+
1353
+ if "no_4x" in one_frame_inference:
1354
+ clean_latents_4x = None
1355
+ clean_latent_4x_indices = None
1356
+ logger.info(f"No clean_latents_4x")
1357
+ else:
1358
+ clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
1359
+ index = 1 + latent_window_size + 1 + 2
1360
+ clean_latent_4x_indices = torch.arange(index, index + 16).unsqueeze(0) # 16
1361
+
1362
+ logger.info(
1363
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
1364
+ )
1365
+
1366
+ # prepare conditioning inputs
1367
+ prompt_index = 0
1368
+ image_index = 0
1369
+
1370
+ context_for_index = context[prompt_index]
1371
+ logger.info(f"Prompt: {context_for_index['prompt']}")
1372
+
1373
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
1374
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
1375
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1376
+
1377
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=torch.bfloat16)
1378
+
1379
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
1380
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
1381
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
1382
+
1383
+ generated_latents = sample_hunyuan(
1384
+ transformer=model,
1385
+ sampler=args.sample_solver,
1386
+ width=width,
1387
+ height=height,
1388
+ frames=1,
1389
+ real_guidance_scale=args.guidance_scale,
1390
+ distilled_guidance_scale=args.embedded_cfg_scale,
1391
+ guidance_rescale=args.guidance_rescale,
1392
+ # shift=3.0,
1393
+ num_inference_steps=args.infer_steps,
1394
+ generator=seed_g,
1395
+ prompt_embeds=llama_vec,
1396
+ prompt_embeds_mask=llama_attention_mask,
1397
+ prompt_poolers=clip_l_pooler,
1398
+ negative_prompt_embeds=llama_vec_n,
1399
+ negative_prompt_embeds_mask=llama_attention_mask_n,
1400
+ negative_prompt_poolers=clip_l_pooler_n,
1401
+ device=device,
1402
+ dtype=torch.bfloat16,
1403
+ image_embeddings=image_encoder_last_hidden_state,
1404
+ latent_indices=latent_indices,
1405
+ clean_latents=clean_latents,
1406
+ clean_latent_indices=clean_latent_indices,
1407
+ clean_latents_2x=clean_latents_2x,
1408
+ clean_latent_2x_indices=clean_latent_2x_indices,
1409
+ clean_latents_4x=clean_latents_4x,
1410
+ clean_latent_4x_indices=clean_latent_4x_indices,
1411
+ )
1412
+
1413
+ real_history_latents = generated_latents.to(clean_latents)
1414
+ return real_history_latents
1415
+
1416
+
1417
+ def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
1418
+ """Save latent to file
1419
+
1420
+ Args:
1421
+ latent: Latent tensor
1422
+ args: command line arguments
1423
+ height: height of frame
1424
+ width: width of frame
1425
+
1426
+ Returns:
1427
+ str: Path to saved latent file
1428
+ """
1429
+ save_path = args.save_path
1430
+ os.makedirs(save_path, exist_ok=True)
1431
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1432
+
1433
+ seed = args.seed
1434
+ video_seconds = args.video_seconds
1435
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
1436
+
1437
+ if args.no_metadata:
1438
+ metadata = None
1439
+ else:
1440
+ metadata = {
1441
+ "seeds": f"{seed}",
1442
+ "prompt": f"{args.prompt}",
1443
+ "height": f"{height}",
1444
+ "width": f"{width}",
1445
+ "video_seconds": f"{video_seconds}",
1446
+ "infer_steps": f"{args.infer_steps}",
1447
+ "guidance_scale": f"{args.guidance_scale}",
1448
+ "latent_window_size": f"{args.latent_window_size}",
1449
+ "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
1450
+ "guidance_rescale": f"{args.guidance_rescale}",
1451
+ "sample_solver": f"{args.sample_solver}",
1452
+ "latent_window_size": f"{args.latent_window_size}",
1453
+ "fps": f"{args.fps}",
1454
+ }
1455
+ if args.negative_prompt is not None:
1456
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
1457
+
1458
+ sd = {"latent": latent.contiguous()}
1459
+ save_file(sd, latent_path, metadata=metadata)
1460
+ logger.info(f"Latent saved to: {latent_path}")
1461
+
1462
+ return latent_path
1463
+
1464
+
1465
+ def save_video(
1466
+ video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
1467
+ ) -> str:
1468
+ """Save video to file
1469
+
1470
+ Args:
1471
+ video: Video tensor
1472
+ args: command line arguments
1473
+ original_base_name: Original base name (if latents are loaded from files)
1474
+
1475
+ Returns:
1476
+ str: Path to saved video file
1477
+ """
1478
+ save_path = args.save_path
1479
+ os.makedirs(save_path, exist_ok=True)
1480
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1481
+
1482
+ seed = args.seed
1483
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1484
+ latent_frames = "" if latent_frames is None else f"_{latent_frames}"
1485
+ video_path = f"{save_path}/{time_flag}_{seed}{original_name}{latent_frames}.mp4"
1486
+
1487
+ video = video.unsqueeze(0)
1488
+ save_videos_grid(video, video_path, fps=args.fps, rescale=True)
1489
+ logger.info(f"Video saved to: {video_path}")
1490
+
1491
+ return video_path
1492
+
1493
+
1494
+ def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
1495
+ """Save images to directory
1496
+
1497
+ Args:
1498
+ sample: Video tensor
1499
+ args: command line arguments
1500
+ original_base_name: Original base name (if latents are loaded from files)
1501
+
1502
+ Returns:
1503
+ str: Path to saved images directory
1504
+ """
1505
+ save_path = args.save_path
1506
+ os.makedirs(save_path, exist_ok=True)
1507
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1508
+
1509
+ seed = args.seed
1510
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1511
+ image_name = f"{time_flag}_{seed}{original_name}"
1512
+ sample = sample.unsqueeze(0)
1513
+ one_frame_mode = args.one_frame_inference is not None
1514
+ save_images_grid(sample, save_path, image_name, rescale=True, create_subdir=not one_frame_mode)
1515
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
1516
+
1517
+ return f"{save_path}/{image_name}"
1518
+
1519
+
1520
+ def save_output(
1521
+ args: argparse.Namespace,
1522
+ vae: AutoencoderKLCausal3D,
1523
+ latent: torch.Tensor,
1524
+ device: torch.device,
1525
+ original_base_names: Optional[List[str]] = None,
1526
+ ) -> None:
1527
+ """save output
1528
+
1529
+ Args:
1530
+ args: command line arguments
1531
+ vae: VAE model
1532
+ latent: latent tensor
1533
+ device: device to use
1534
+ original_base_names: original base names (if latents are loaded from files)
1535
+ """
1536
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
1537
+ height *= 8
1538
+ width *= 8
1539
+ # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}")
1540
+ if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
1541
+ # save latent
1542
+ save_latent(latent, args, height, width)
1543
+ if args.output_type == "latent":
1544
+ return
1545
+
1546
+ total_latent_sections = (args.video_seconds * 30) / (args.latent_window_size * 4)
1547
+ total_latent_sections = int(max(round(total_latent_sections), 1))
1548
+ video = decode_latent(
1549
+ args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent, device, args.one_frame_inference is not None
1550
+ )
1551
+
1552
+ if args.output_type == "video" or args.output_type == "both":
1553
+ # save video
1554
+ original_name = None if original_base_names is None else original_base_names[0]
1555
+ save_video(video, args, original_name)
1556
+
1557
+ elif args.output_type == "images" or args.output_type == "latent_images":
1558
+ # save images
1559
+ original_name = None if original_base_names is None else original_base_names[0]
1560
+ save_images(video, args, original_name)
1561
+
1562
+
1563
+ def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
1564
+ """Process multiple prompts for batch mode
1565
+
1566
+ Args:
1567
+ prompt_lines: List of prompt lines
1568
+ base_args: Base command line arguments
1569
+
1570
+ Returns:
1571
+ List[Dict]: List of prompt data dictionaries
1572
+ """
1573
+ prompts_data = []
1574
+
1575
+ for line in prompt_lines:
1576
+ line = line.strip()
1577
+ if not line or line.startswith("#"): # Skip empty lines and comments
1578
+ continue
1579
+
1580
+ # Parse prompt line and create override dictionary
1581
+ prompt_data = parse_prompt_line(line)
1582
+ logger.info(f"Parsed prompt data: {prompt_data}")
1583
+ prompts_data.append(prompt_data)
1584
+
1585
+ return prompts_data
1586
+
1587
+
1588
+ def load_shared_models(args: argparse.Namespace) -> Dict:
1589
+ """Load shared models for batch processing or interactive mode.
1590
+ Models are loaded to CPU to save memory.
1591
+
1592
+ Args:
1593
+ args: Base command line arguments
1594
+
1595
+ Returns:
1596
+ Dict: Dictionary of shared models
1597
+ """
1598
+ shared_models = {}
1599
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, "cpu")
1600
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
1601
+ feature_extractor, image_encoder = load_image_encoders(args)
1602
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
1603
+ shared_models["tokenizer1"] = tokenizer1
1604
+ shared_models["text_encoder1"] = text_encoder1
1605
+ shared_models["tokenizer2"] = tokenizer2
1606
+ shared_models["text_encoder2"] = text_encoder2
1607
+ shared_models["feature_extractor"] = feature_extractor
1608
+ shared_models["image_encoder"] = image_encoder
1609
+ shared_models["vae"] = vae
1610
+
1611
+ return shared_models
1612
+
1613
+
1614
+ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
1615
+ """Process multiple prompts with model reuse
1616
+
1617
+ Args:
1618
+ prompts_data: List of prompt data dictionaries
1619
+ args: Base command line arguments
1620
+ """
1621
+ if not prompts_data:
1622
+ logger.warning("No valid prompts found")
1623
+ return
1624
+
1625
+ # 1. Load configuration
1626
+ gen_settings = get_generation_settings(args)
1627
+ device = gen_settings.device
1628
+
1629
+ # 2. Load models to CPU in advance except for VAE and DiT
1630
+ shared_models = load_shared_models(args)
1631
+
1632
+ # 3. Generate for each prompt
1633
+ all_latents = []
1634
+ all_prompt_args = []
1635
+
1636
+ with torch.no_grad():
1637
+ for prompt_data in prompts_data:
1638
+ prompt = prompt_data["prompt"]
1639
+ prompt_args = apply_overrides(args, prompt_data)
1640
+ logger.info(f"Processing prompt: {prompt}")
1641
+
1642
+ try:
1643
+ vae, latent = generate(prompt_args, gen_settings, shared_models)
1644
+
1645
+ # Save latent if needed
1646
+ if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
1647
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
1648
+ height *= 8
1649
+ width *= 8
1650
+ save_latent(latent, prompt_args, height, width)
1651
+
1652
+ all_latents.append(latent)
1653
+ all_prompt_args.append(prompt_args)
1654
+ except Exception as e:
1655
+ logger.error(f"Error processing prompt: {prompt}. Error: {e}")
1656
+ continue
1657
+
1658
+ # 4. Free models
1659
+ if "model" in shared_models:
1660
+ del shared_models["model"]
1661
+ del shared_models["tokenizer1"]
1662
+ del shared_models["text_encoder1"]
1663
+ del shared_models["tokenizer2"]
1664
+ del shared_models["text_encoder2"]
1665
+ del shared_models["feature_extractor"]
1666
+ del shared_models["image_encoder"]
1667
+
1668
+ clean_memory_on_device(device)
1669
+ synchronize_device(device)
1670
+
1671
+ # 5. Decode latents if needed
1672
+ if args.output_type != "latent":
1673
+ logger.info("Decoding latents to videos/images")
1674
+ vae.to(device)
1675
+
1676
+ for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
1677
+ logger.info(f"Decoding output {i+1}/{len(all_latents)}")
1678
+
1679
+ # avoid saving latents again (ugly hack)
1680
+ if prompt_args.output_type == "both":
1681
+ prompt_args.output_type = "video"
1682
+ elif prompt_args.output_type == "latent_images":
1683
+ prompt_args.output_type = "images"
1684
+
1685
+ save_output(prompt_args, vae, latent[0], device)
1686
+
1687
+
1688
+ def process_interactive(args: argparse.Namespace) -> None:
1689
+ """Process prompts in interactive mode
1690
+
1691
+ Args:
1692
+ args: Base command line arguments
1693
+ """
1694
+ gen_settings = get_generation_settings(args)
1695
+ device = gen_settings.device
1696
+ shared_models = load_shared_models(args)
1697
+
1698
+ print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
1699
+
1700
+ try:
1701
+ while True:
1702
+ try:
1703
+ line = input("> ")
1704
+ if not line.strip():
1705
+ continue
1706
+
1707
+ # Parse prompt
1708
+ prompt_data = parse_prompt_line(line)
1709
+ prompt_args = apply_overrides(args, prompt_data)
1710
+
1711
+ # Generate latent
1712
+ vae, latent = generate(prompt_args, gen_settings, shared_models)
1713
+
1714
+ # Save latent and video
1715
+ save_output(prompt_args, vae, latent[0], device)
1716
+
1717
+ except KeyboardInterrupt:
1718
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
1719
+ continue
1720
+
1721
+ except EOFError:
1722
+ print("\nExiting interactive mode")
1723
+
1724
+
1725
+ def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
1726
+ device = torch.device(args.device)
1727
+
1728
+ dit_weight_dtype = None # default
1729
+ if args.fp8_scaled:
1730
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
1731
+ elif args.fp8:
1732
+ dit_weight_dtype = torch.float8_e4m3fn
1733
+
1734
+ logger.info(f"Using device: {device}, DiT weight precision: {dit_weight_dtype}")
1735
+
1736
+ gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
1737
+ return gen_settings
1738
+
1739
+
1740
+ def main():
1741
+ # Parse arguments
1742
+ args = parse_args()
1743
+
1744
+ # Check if latents are provided
1745
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
1746
+
1747
+ # Set device
1748
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
1749
+ device = torch.device(device)
1750
+ logger.info(f"Using device: {device}")
1751
+ args.device = device
1752
+
1753
+ if latents_mode:
1754
+ # Original latent decode mode
1755
+ original_base_names = []
1756
+ latents_list = []
1757
+ seeds = []
1758
+
1759
+ # assert len(args.latent_path) == 1, "Only one latent path is supported for now"
1760
+
1761
+ for latent_path in args.latent_path:
1762
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
1763
+ seed = 0
1764
+
1765
+ if os.path.splitext(latent_path)[1] != ".safetensors":
1766
+ latents = torch.load(latent_path, map_location="cpu")
1767
+ else:
1768
+ latents = load_file(latent_path)["latent"]
1769
+ with safe_open(latent_path, framework="pt") as f:
1770
+ metadata = f.metadata()
1771
+ if metadata is None:
1772
+ metadata = {}
1773
+ logger.info(f"Loaded metadata: {metadata}")
1774
+
1775
+ if "seeds" in metadata:
1776
+ seed = int(metadata["seeds"])
1777
+ if "height" in metadata and "width" in metadata:
1778
+ height = int(metadata["height"])
1779
+ width = int(metadata["width"])
1780
+ args.video_size = [height, width]
1781
+ if "video_seconds" in metadata:
1782
+ args.video_seconds = float(metadata["video_seconds"])
1783
+
1784
+ seeds.append(seed)
1785
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
1786
+
1787
+ if latents.ndim == 5: # [BCTHW]
1788
+ latents = latents.squeeze(0) # [CTHW]
1789
+
1790
+ latents_list.append(latents)
1791
+
1792
+ # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
1793
+
1794
+ for i, latent in enumerate(latents_list):
1795
+ args.seed = seeds[i]
1796
+
1797
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
1798
+ save_output(args, vae, latent, device, original_base_names)
1799
+
1800
+ elif args.from_file:
1801
+ # Batch mode from file
1802
+
1803
+ # Read prompts from file
1804
+ with open(args.from_file, "r", encoding="utf-8") as f:
1805
+ prompt_lines = f.readlines()
1806
+
1807
+ # Process prompts
1808
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
1809
+ process_batch_prompts(prompts_data, args)
1810
+
1811
+ elif args.interactive:
1812
+ # Interactive mode
1813
+ process_interactive(args)
1814
+
1815
+ else:
1816
+ # Single prompt mode (original behavior)
1817
+
1818
+ # Generate latent
1819
+ gen_settings = get_generation_settings(args)
1820
+ vae, latent = generate(args, gen_settings)
1821
+ # print(f"Generated latent shape: {latent.shape}")
1822
+ if args.save_merged_model:
1823
+ return
1824
+
1825
+ # Save latent and video
1826
+ save_output(args, vae, latent[0], device)
1827
+
1828
+ logger.info("Done!")
1829
+
1830
+
1831
+ if __name__ == "__main__":
1832
+ main()
fpack_train_network.py ADDED
@@ -0,0 +1,617 @@
1
+ import argparse
2
+ import gc
3
+ import math
4
+ import time
5
+ from typing import Optional
6
+ from PIL import Image
7
+
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchvision.transforms.functional as TF
12
+ from tqdm import tqdm
13
+ from accelerate import Accelerator, init_empty_weights
14
+
15
+ from dataset import image_video_dataset
16
+ from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ARCHITECTURE_FRAMEPACK_FULL, load_video
17
+ from fpack_generate_video import decode_latent
18
+ from frame_pack import hunyuan
19
+ from frame_pack.clip_vision import hf_clip_vision_encode
20
+ from frame_pack.framepack_utils import load_image_encoders, load_text_encoder1, load_text_encoder2
21
+ from frame_pack.framepack_utils import load_vae as load_framepack_vae
22
+ from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
23
+ from frame_pack.k_diffusion_hunyuan import sample_hunyuan
24
+ from frame_pack.utils import crop_or_pad_yield_mask
25
+ from dataset.image_video_dataset import resize_image_to_bucket
26
+ from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
27
+
28
+ import logging
29
+
30
+ logger = logging.getLogger(__name__)
31
+ logging.basicConfig(level=logging.INFO)
32
+
33
+ from utils import model_utils
34
+ from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
35
+
36
+
37
+ class FramePackNetworkTrainer(NetworkTrainer):
38
+ def __init__(self):
39
+ super().__init__()
40
+
41
+ # region model specific
42
+
43
+ @property
44
+ def architecture(self) -> str:
45
+ return ARCHITECTURE_FRAMEPACK
46
+
47
+ @property
48
+ def architecture_full_name(self) -> str:
49
+ return ARCHITECTURE_FRAMEPACK_FULL
50
+
51
+ def handle_model_specific_args(self, args):
52
+ self._i2v_training = True
53
+ self._control_training = False
54
+ self.default_guidance_scale = 10.0 # embedded guidance scale
55
+
56
+ def process_sample_prompts(
57
+ self,
58
+ args: argparse.Namespace,
59
+ accelerator: Accelerator,
60
+ sample_prompts: str,
61
+ ):
62
+ device = accelerator.device
63
+
64
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
65
+ prompts = load_prompts(sample_prompts)
66
+
67
+ # load text encoder
68
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
69
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
70
+ text_encoder2.to(device)
71
+
72
+ sample_prompts_te_outputs = {} # (prompt) -> (t1 embeds, t1 mask, t2 embeds)
73
+ for prompt_dict in prompts:
74
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
75
+ if p is None or p in sample_prompts_te_outputs:
76
+ continue
77
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
78
+ with torch.amp.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
79
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(p, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
80
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
81
+
82
+ llama_vec = llama_vec.to("cpu")
83
+ llama_attention_mask = llama_attention_mask.to("cpu")
84
+ clip_l_pooler = clip_l_pooler.to("cpu")
85
+ sample_prompts_te_outputs[p] = (llama_vec, llama_attention_mask, clip_l_pooler)
86
+ del text_encoder1, text_encoder2
87
+ clean_memory_on_device(device)
88
+
89
+ # image embedding for I2V training
90
+ feature_extractor, image_encoder = load_image_encoders(args)
91
+ image_encoder.to(device)
92
+
93
+ # encode image with image encoder
94
+ sample_prompts_image_embs = {}
95
+ for prompt_dict in prompts:
96
+ image_path = prompt_dict.get("image_path", None)
97
+ assert image_path is not None, "image_path should be set for I2V training"
98
+ if image_path in sample_prompts_image_embs:
99
+ continue
100
+
101
+ logger.info(f"Encoding image to image encoder context: {image_path}")
102
+
103
+ height = prompt_dict.get("height", 256)
104
+ width = prompt_dict.get("width", 256)
105
+
106
+ img = Image.open(image_path).convert("RGB")
107
+ img_np = np.array(img) # PIL to numpy, HWC
108
+ img_np = image_video_dataset.resize_image_to_bucket(img_np, (width, height)) # returns a numpy array
109
+
110
+ with torch.no_grad():
111
+ image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
112
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
113
+
114
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to("cpu")
115
+ sample_prompts_image_embs[image_path] = image_encoder_last_hidden_state
116
+
117
+ del image_encoder
118
+ clean_memory_on_device(device)
119
+
120
+ # prepare sample parameters
121
+ sample_parameters = []
122
+ for prompt_dict in prompts:
123
+ prompt_dict_copy = prompt_dict.copy()
124
+
125
+ p = prompt_dict.get("prompt", "")
126
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
127
+ prompt_dict_copy["llama_vec"] = llama_vec
128
+ prompt_dict_copy["llama_attention_mask"] = llama_attention_mask
129
+ prompt_dict_copy["clip_l_pooler"] = clip_l_pooler
130
+
131
+ p = prompt_dict.get("negative_prompt", "")
132
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
133
+ prompt_dict_copy["negative_llama_vec"] = llama_vec
134
+ prompt_dict_copy["negative_llama_attention_mask"] = llama_attention_mask
135
+ prompt_dict_copy["negative_clip_l_pooler"] = clip_l_pooler
136
+
137
+ p = prompt_dict.get("image_path", None)
138
+ prompt_dict_copy["image_encoder_last_hidden_state"] = sample_prompts_image_embs[p]
139
+
140
+ sample_parameters.append(prompt_dict_copy)
141
+
142
+ clean_memory_on_device(accelerator.device)
143
+ return sample_parameters
144
+
145
+ def do_inference(
146
+ self,
147
+ accelerator,
148
+ args,
149
+ sample_parameter,
150
+ vae,
151
+ dit_dtype,
152
+ transformer,
153
+ discrete_flow_shift,
154
+ sample_steps,
155
+ width,
156
+ height,
157
+ frame_count,
158
+ generator,
159
+ do_classifier_free_guidance,
160
+ guidance_scale,
161
+ cfg_scale,
162
+ image_path=None,
163
+ control_video_path=None,
164
+ ):
165
+ """architecture dependent inference"""
166
+ model: HunyuanVideoTransformer3DModelPacked = transformer
167
+ device = accelerator.device
168
+ if cfg_scale is None:
169
+ cfg_scale = 1.0
170
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
171
+
172
+ # prepare parameters
173
+ one_frame_mode = args.one_frame
174
+ if one_frame_mode:
175
+ one_frame_inference = set()
176
+ for mode in sample_parameter["one_frame"].split(","):
177
+ one_frame_inference.add(mode.strip())
178
+ else:
179
+ one_frame_inference = None
180
+
181
+ latent_window_size = args.latent_window_size # default is 9
182
+ latent_f = (frame_count - 1) // 4 + 1
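+ # HunyuanVideo VAE temporal compression: the first frame maps to one latent frame, then every 4 video frames add one more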
183
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
184
+ if total_latent_sections < 1 and not one_frame_mode:
185
+ logger.warning(f"Not enough frames for FramePack: {latent_f}, minimum: {latent_window_size*4+1}")
186
+ return None
187
+
188
+ latent_f = total_latent_sections * latent_window_size + 1
189
+ actual_frame_count = (latent_f - 1) * 4 + 1
190
+ if actual_frame_count != frame_count:
191
+ logger.info(f"Frame count mismatch: {actual_frame_count} != {frame_count}, trimming to {actual_frame_count}")
192
+ frame_count = actual_frame_count
193
+ num_frames = latent_window_size * 4 - 3
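+ # each section of latent_window_size latent frames decodes to latent_window_size * 4 - 3 video frames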
194
+
195
+ # prepare start and control latent
196
+ def encode_image(path):
197
+ image = Image.open(path)
198
+ if image.mode == "RGBA":
199
+ alpha = image.split()[-1]
200
+ image = image.convert("RGB")
201
+ else:
202
+ alpha = None
203
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
204
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).unsqueeze(0).float() # 1, C, 1, H, W
205
+ image = image / 127.5 - 1 # -1 to 1
206
+ return hunyuan.vae_encode(image, vae).to("cpu"), alpha
207
+
208
+ # VAE encoding
209
+ logger.info(f"Encoding image to latent space")
210
+ vae.to(device)
211
+
212
+ start_latent, _ = (
+     encode_image(image_path)
+     if image_path
+     else (torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32), None)
+ )
215
+
216
+ if one_frame_mode:
217
+ control_latents = []
218
+ control_alphas = []
219
+ if "control_image_path" in sample_parameter:
220
+ for control_image_path in sample_parameter["control_image_path"]:
221
+ control_latent, control_alpha = encode_image(control_image_path)
222
+ control_latents.append(control_latent)
223
+ control_alphas.append(control_alpha)
224
+ else:
225
+ control_latents = None
226
+ control_alphas = None
227
+
228
+ vae.to("cpu") # move VAE to CPU to save memory
229
+ clean_memory_on_device(device)
230
+
231
+ # sampling
232
+ if not one_frame_mode:
233
+ f1_mode = args.f1
234
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
235
+
236
+ if not f1_mode:
237
+ total_generated_latent_frames = 0
238
+ latent_paddings = reversed(range(total_latent_sections))
239
+ else:
240
+ total_generated_latent_frames = 1
241
+ history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
242
+ latent_paddings = [0] * total_latent_sections
243
+
244
+ if total_latent_sections > 4:
245
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
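+ # heuristic from the original FramePack implementation: for more than 4 sections, the padding schedule 3, 2, ..., 2, 1, 0 is reported to work better than the plain reversed range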
246
+
247
+ latent_paddings = list(latent_paddings)
248
+ for loop_index in range(total_latent_sections):
249
+ latent_padding = latent_paddings[loop_index]
250
+
251
+ if not f1_mode:
252
+ is_last_section = latent_padding == 0
253
+ latent_padding_size = latent_padding * latent_window_size
254
+
255
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
256
+
257
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
258
+ (
259
+ clean_latent_indices_pre,
260
+ blank_indices,
261
+ latent_indices,
262
+ clean_latent_indices_post,
263
+ clean_latent_2x_indices,
264
+ clean_latent_4x_indices,
265
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
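+ # index layout in padded (reverse) mode: [start latent, blank padding, current window, clean post, clean 2x, clean 4x]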
266
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
267
+
268
+ clean_latents_pre = start_latent.to(history_latents)
269
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
270
+ [1, 2, 16], dim=2
271
+ )
272
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
273
+ else:
274
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
275
+ (
276
+ clean_latent_indices_start,
277
+ clean_latent_4x_indices,
278
+ clean_latent_2x_indices,
279
+ clean_latent_1x_indices,
280
+ latent_indices,
281
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
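+ # index layout in F1 (forward) mode: [start latent, clean 4x history, clean 2x, clean 1x, current window]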
282
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
283
+
284
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
285
+ [16, 2, 1], dim=2
286
+ )
287
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
288
+
289
+ # if use_teacache:
290
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
291
+ # else:
292
+ # transformer.initialize_teacache(enable_teacache=False)
293
+
294
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
295
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
296
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
297
+ if cfg_scale == 1.0:
298
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
299
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
300
+ else:
301
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
302
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
303
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
304
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
305
+ device, dtype=torch.bfloat16
306
+ )
307
+
308
+ generated_latents = sample_hunyuan(
309
+ transformer=model,
310
+ sampler=args.sample_solver,
311
+ width=width,
312
+ height=height,
313
+ frames=num_frames,
314
+ real_guidance_scale=cfg_scale,
315
+ distilled_guidance_scale=guidance_scale,
316
+ guidance_rescale=0.0,
317
+ # shift=3.0,
318
+ num_inference_steps=sample_steps,
319
+ generator=generator,
320
+ prompt_embeds=llama_vec,
321
+ prompt_embeds_mask=llama_attention_mask,
322
+ prompt_poolers=clip_l_pooler,
323
+ negative_prompt_embeds=llama_vec_n,
324
+ negative_prompt_embeds_mask=llama_attention_mask_n,
325
+ negative_prompt_poolers=clip_l_pooler_n,
326
+ device=device,
327
+ dtype=torch.bfloat16,
328
+ image_embeddings=image_encoder_last_hidden_state,
329
+ latent_indices=latent_indices,
330
+ clean_latents=clean_latents,
331
+ clean_latent_indices=clean_latent_indices,
332
+ clean_latents_2x=clean_latents_2x,
333
+ clean_latent_2x_indices=clean_latent_2x_indices,
334
+ clean_latents_4x=clean_latents_4x,
335
+ clean_latent_4x_indices=clean_latent_4x_indices,
336
+ )
337
+
338
+ total_generated_latent_frames += int(generated_latents.shape[2])
339
+ if not f1_mode:
340
+ if is_last_section:
341
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
342
+ total_generated_latent_frames += 1
343
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
344
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
345
+ else:
346
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
347
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
348
+
349
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
350
+ else:
351
+ # one frame mode
352
+ sample_num_frames = 1
353
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
354
+ latent_indices[:, 0] = latent_window_size # last of latent_window
355
+
356
+ def get_latent_mask(mask_image: Image.Image):
357
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
358
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
359
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
360
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
361
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (B, C, F, H, W)
362
+ mask_image = mask_image.to(torch.float32)
363
+ return mask_image
364
+
365
+ if control_latents is None or len(control_latents) == 0:
366
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
367
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
368
+
369
+ if "no_post" not in one_frame_inference:
370
+ # add zero latents as clean latents post
371
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
372
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
373
+
374
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
375
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
376
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
377
+ if "no_post" not in one_frame_inference:
378
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
379
+
380
+ # apply mask for control latents (clean latents)
381
+ for i in range(len(control_alphas)):
382
+ control_alpha = control_alphas[i]
383
+ if control_alpha is not None:
384
+ latent_mask = get_latent_mask(control_alpha)
385
+ logger.info(
386
+ f"Apply mask for clean latents 1x for {i+1}: shape: {latent_mask.shape}"
387
+ )
388
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * latent_mask
389
+
390
+ for one_frame_param in one_frame_inference:
391
+ if one_frame_param.startswith("target_index="):
392
+ target_index = int(one_frame_param.split("=")[1])
393
+ latent_indices[:, 0] = target_index
394
+ logger.info(f"Set index for target: {target_index}")
395
+ elif one_frame_param.startswith("control_index="):
396
+ control_indices = one_frame_param.split("=")[1].split(";")
397
+ i = 0
398
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
399
+ control_index = int(control_indices[i])
400
+ clean_latent_indices[:, i] = control_index
401
+ i += 1
402
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
403
+
404
+ if "no_2x" in one_frame_inference:
405
+ clean_latents_2x = None
406
+ clean_latent_2x_indices = None
407
+ logger.info(f"No clean_latents_2x")
408
+ else:
409
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
410
+ index = 1 + latent_window_size + 1
411
+ clean_latent_2x_indices = torch.arange(index, index + 2) # 2
412
+
413
+ if "no_4x" in one_frame_inference:
414
+ clean_latents_4x = None
415
+ clean_latent_4x_indices = None
416
+ logger.info(f"No clean_latents_4x")
417
+ else:
+     clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
+     index = 1 + latent_window_size + 1 + 2
+     clean_latent_4x_indices = torch.arange(index, index + 16)  # 16
420
+
421
+ logger.info(
422
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
423
+ )
424
+
425
+ # prepare conditioning inputs
426
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
427
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
428
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
429
+ if cfg_scale == 1.0:
430
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
431
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
432
+ else:
433
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
434
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
435
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
436
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
437
+ device, dtype=torch.bfloat16
438
+ )
439
+
440
+ generated_latents = sample_hunyuan(
441
+ transformer=model,
442
+ sampler=args.sample_solver,
443
+ width=width,
444
+ height=height,
445
+ frames=1,
446
+ real_guidance_scale=cfg_scale,
447
+ distilled_guidance_scale=guidance_scale,
448
+ guidance_rescale=0.0,
449
+ # shift=3.0,
450
+ num_inference_steps=sample_steps,
451
+ generator=generator,
452
+ prompt_embeds=llama_vec,
453
+ prompt_embeds_mask=llama_attention_mask,
454
+ prompt_poolers=clip_l_pooler,
455
+ negative_prompt_embeds=llama_vec_n,
456
+ negative_prompt_embeds_mask=llama_attention_mask_n,
457
+ negative_prompt_poolers=clip_l_pooler_n,
458
+ device=device,
459
+ dtype=torch.bfloat16,
460
+ image_embeddings=image_encoder_last_hidden_state,
461
+ latent_indices=latent_indices,
462
+ clean_latents=clean_latents,
463
+ clean_latent_indices=clean_latent_indices,
464
+ clean_latents_2x=clean_latents_2x,
465
+ clean_latent_2x_indices=clean_latent_2x_indices,
466
+ clean_latents_4x=clean_latents_4x,
467
+ clean_latent_4x_indices=clean_latent_4x_indices,
468
+ )
469
+
470
+ real_history_latents = generated_latents.to(clean_latents)
471
+
472
+ # wait for 5 seconds until block swap is done
473
+ logger.info("Waiting for 5 seconds to finish block swap")
474
+ time.sleep(5)
475
+
476
+ gc.collect()
477
+ clean_memory_on_device(device)
478
+
479
+ video = decode_latent(
480
+ latent_window_size, total_latent_sections, args.bulk_decode, vae, real_history_latents, device, one_frame_mode
481
+ )
482
+ video = video.to("cpu", dtype=torch.float32).unsqueeze(0) # add batch dimension
483
+ video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
484
+ clean_memory_on_device(device)
485
+
486
+ return video
487
+
488
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
489
+ vae_path = args.vae
490
+ logger.info(f"Loading VAE model from {vae_path}")
491
+ vae = load_framepack_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
492
+ return vae
493
+
494
+ def load_transformer(
495
+ self,
496
+ accelerator: Accelerator,
497
+ args: argparse.Namespace,
498
+ dit_path: str,
499
+ attn_mode: str,
500
+ split_attn: bool,
501
+ loading_device: str,
502
+ dit_weight_dtype: Optional[torch.dtype],
503
+ ):
504
+ logger.info(f"Loading DiT model from {dit_path}")
505
+ device = accelerator.device
506
+ model = load_packed_model(device, dit_path, attn_mode, loading_device, args.fp8_scaled, split_attn)
507
+ return model
508
+
509
+ def scale_shift_latents(self, latents):
510
+ # FramePack VAE includes scaling
511
+ return latents
512
+
513
+ def call_dit(
514
+ self,
515
+ args: argparse.Namespace,
516
+ accelerator: Accelerator,
517
+ transformer,
518
+ latents: torch.Tensor,
519
+ batch: dict[str, torch.Tensor],
520
+ noise: torch.Tensor,
521
+ noisy_model_input: torch.Tensor,
522
+ timesteps: torch.Tensor,
523
+ network_dtype: torch.dtype,
524
+ ):
525
+ model: HunyuanVideoTransformer3DModelPacked = transformer
526
+ device = accelerator.device
527
+ batch_size = latents.shape[0]
528
+
529
+ # maybe model.dtype is better than network_dtype...
530
+ distilled_guidance = torch.tensor([args.guidance_scale * 1000.0] * batch_size).to(device=device, dtype=network_dtype)
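+ # guidance-distilled conditioning: guidance_scale * 1000 is fed to the model as an embedding instead of applying classifier-free guidance here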
531
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
532
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
533
+ # for k, v in batch.items():
534
+ # if isinstance(v, torch.Tensor):
535
+ # print(f"{k}: {v.shape} {v.dtype} {v.device}")
536
+ with accelerator.autocast():
537
+ clean_latent_2x_indices = batch["clean_latent_2x_indices"] if "clean_latent_2x_indices" in batch else None
538
+ if clean_latent_2x_indices is not None:
539
+ clean_latent_2x = batch["latents_clean_2x"] if "latents_clean_2x" in batch else None
540
+ if clean_latent_2x is None:
541
+ clean_latent_2x = torch.zeros(
542
+ (batch_size, 16, 2, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
543
+ )
544
+ else:
545
+ clean_latent_2x = None
546
+
547
+ clean_latent_4x_indices = batch["clean_latent_4x_indices"] if "clean_latent_4x_indices" in batch else None
548
+ if clean_latent_4x_indices is not None:
549
+ clean_latent_4x = batch["latents_clean_4x"] if "latents_clean_4x" in batch else None
550
+ if clean_latent_4x is None:
551
+ clean_latent_4x = torch.zeros(
552
+ (batch_size, 16, 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
553
+ )
554
+ else:
555
+ clean_latent_4x = None
556
+
557
+ model_pred = model(
558
+ hidden_states=noisy_model_input,
559
+ timestep=timesteps,
560
+ encoder_hidden_states=batch["llama_vec"],
561
+ encoder_attention_mask=batch["llama_attention_mask"],
562
+ pooled_projections=batch["clip_l_pooler"],
563
+ guidance=distilled_guidance,
564
+ latent_indices=batch["latent_indices"],
565
+ clean_latents=batch["latents_clean"],
566
+ clean_latent_indices=batch["clean_latent_indices"],
567
+ clean_latents_2x=clean_latent_2x,
568
+ clean_latent_2x_indices=clean_latent_2x_indices,
569
+ clean_latents_4x=clean_latent_4x,
570
+ clean_latent_4x_indices=clean_latent_4x_indices,
571
+ image_embeddings=batch["image_embeddings"],
572
+ return_dict=False,
573
+ )
574
+ model_pred = model_pred[0] # returns tuple (model_pred, )
575
+
576
+ # flow matching loss
577
+ target = noise - latents
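+ # velocity target for flow matching: with x_t interpolating from latents to noise, d x_t / dt = noise - latents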
578
+
579
+ return model_pred, target
580
+
581
+ # endregion model specific
582
+
583
+
584
+ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
585
+ """FramePack specific parser setup"""
586
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
587
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
588
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
589
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
590
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
591
+ parser.add_argument(
592
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
593
+ )
594
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
595
+ parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
596
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once in sample generation")
597
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method for sample generation")
598
+ parser.add_argument("--one_frame", action="store_true", help="Use one frame sampling method for sample generation")
599
+ return parser
600
+
601
+
602
+ if __name__ == "__main__":
603
+ parser = setup_parser_common()
604
+ parser = framepack_setup_parser(parser)
605
+
606
+ args = parser.parse_args()
607
+ args = read_config_from_file(args, parser)
608
+
609
+ assert (
610
+ args.vae_dtype is None or args.vae_dtype == "float16"
611
+ ), "VAE dtype must be float16 / VAEのdtypeはfloat16でなければなりません"
612
+ args.vae_dtype = "float16" # fixed
613
+ args.dit_dtype = "bfloat16" # fixed
614
+ args.sample_solver = "unipc" # for sample generation, fixed to unipc
615
+
616
+ trainer = FramePackNetworkTrainer()
617
+ trainer.train(args)
frame_pack/__init__.py ADDED
File without changes
frame_pack/bucket_tools.py ADDED
@@ -0,0 +1,30 @@
+ bucket_options = {
+     640: [
+         (416, 960),
+         (448, 864),
+         (480, 832),
+         (512, 768),
+         (544, 704),
+         (576, 672),
+         (608, 640),
+         (640, 608),
+         (672, 576),
+         (704, 544),
+         (768, 512),
+         (832, 480),
+         (864, 448),
+         (960, 416),
+     ],
+ }
+
+
+ def find_nearest_bucket(h, w, resolution=640):
+     min_metric = float('inf')
+     best_bucket = None
+     for (bucket_h, bucket_w) in bucket_options[resolution]:
+         metric = abs(h * bucket_w - w * bucket_h)
+         if metric <= min_metric:
+             min_metric = metric
+             best_bucket = (bucket_h, bucket_w)
+     return best_bucket
+
frame_pack/clip_vision.py ADDED
@@ -0,0 +1,14 @@
+ import numpy as np
+
+
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
+     assert isinstance(image, np.ndarray)
+     assert image.ndim == 3 and image.shape[2] == 3
+     assert image.dtype == np.uint8
+
+     preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(
+         device=image_encoder.device, dtype=image_encoder.dtype
+     )
+     image_encoder_output = image_encoder(**preprocessed)
+
+     return image_encoder_output
frame_pack/framepack_utils.py ADDED
@@ -0,0 +1,273 @@
1
+ import os
2
+ import logging
3
+ from types import SimpleNamespace
4
+ from typing import Optional, Union
5
+
6
+ import accelerate
7
+ from accelerate import Accelerator, init_empty_weights
8
+ import torch
9
+ from safetensors.torch import load_file
10
+ from transformers import (
11
+ LlamaTokenizerFast,
12
+ LlamaConfig,
13
+ LlamaModel,
14
+ CLIPTokenizer,
15
+ CLIPTextModel,
16
+ CLIPConfig,
17
+ SiglipImageProcessor,
18
+ SiglipVisionModel,
19
+ SiglipVisionConfig,
20
+ )
21
+
22
+ from utils.safetensors_utils import load_split_weights
23
+ from hunyuan_model.vae import load_vae as hunyuan_load_vae
24
+
25
+ import logging
26
+
27
+ logger = logging.getLogger(__name__)
28
+ logging.basicConfig(level=logging.INFO)
29
+
30
+
31
+ def load_vae(
32
+ vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
33
+ ):
34
+ # single file and directory (contains 'vae') support
35
+ if os.path.isdir(vae_path):
36
+ vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")
37
+ else:
38
+ vae_path = vae_path
39
+
40
+ vae_dtype = torch.float16 # if vae_dtype is None else str_to_dtype(vae_dtype)
41
+ vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
42
+ vae.eval()
43
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
44
+
45
+ # set chunk_size to CausalConv3d recursively
46
+ chunk_size = vae_chunk_size
47
+ if chunk_size is not None:
48
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
49
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
50
+
51
+ if vae_spatial_tile_sample_min_size is not None:
52
+ vae.enable_spatial_tiling(True)
53
+ vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
54
+ vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
55
+ logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
56
+ # elif vae_tiling:
57
+ else:
58
+ vae.enable_spatial_tiling(True)
59
+
60
+ return vae
61
+
62
+
63
+ # region Text Encoders
64
+
65
+ # Text Encoder configs are copied from HunyuanVideo repo
66
+
67
+ LLAMA_CONFIG = {
68
+ "architectures": ["LlamaModel"],
69
+ "attention_bias": False,
70
+ "attention_dropout": 0.0,
71
+ "bos_token_id": 128000,
72
+ "eos_token_id": 128001,
73
+ "head_dim": 128,
74
+ "hidden_act": "silu",
75
+ "hidden_size": 4096,
76
+ "initializer_range": 0.02,
77
+ "intermediate_size": 14336,
78
+ "max_position_embeddings": 8192,
79
+ "mlp_bias": False,
80
+ "model_type": "llama",
81
+ "num_attention_heads": 32,
82
+ "num_hidden_layers": 32,
83
+ "num_key_value_heads": 8,
84
+ "pretraining_tp": 1,
85
+ "rms_norm_eps": 1e-05,
86
+ "rope_scaling": None,
87
+ "rope_theta": 500000.0,
88
+ "tie_word_embeddings": False,
89
+ "torch_dtype": "float16",
90
+ "transformers_version": "4.46.3",
91
+ "use_cache": True,
92
+ "vocab_size": 128320,
93
+ }
94
+
95
+ CLIP_CONFIG = {
96
+ # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
97
+ "architectures": ["CLIPTextModel"],
98
+ "attention_dropout": 0.0,
99
+ "bos_token_id": 0,
100
+ "dropout": 0.0,
101
+ "eos_token_id": 2,
102
+ "hidden_act": "quick_gelu",
103
+ "hidden_size": 768,
104
+ "initializer_factor": 1.0,
105
+ "initializer_range": 0.02,
106
+ "intermediate_size": 3072,
107
+ "layer_norm_eps": 1e-05,
108
+ "max_position_embeddings": 77,
109
+ "model_type": "clip_text_model",
110
+ "num_attention_heads": 12,
111
+ "num_hidden_layers": 12,
112
+ "pad_token_id": 1,
113
+ "projection_dim": 768,
114
+ "torch_dtype": "float16",
115
+ "transformers_version": "4.48.0.dev0",
116
+ "vocab_size": 49408,
117
+ }
118
+
119
+
120
+ def load_text_encoder1(
121
+ args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
122
+ ) -> tuple[LlamaTokenizerFast, LlamaModel]:
123
+ # single file, split file and directory (contains 'text_encoder') support
124
+ logger.info(f"Loading text encoder 1 tokenizer")
125
+ tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")
126
+
127
+ logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
128
+ if os.path.isdir(args.text_encoder1):
129
+ # load from directory, configs are in the directory
130
+ text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
131
+ else:
132
+ # load from file, we create the model with the appropriate config
133
+ config = LlamaConfig(**LLAMA_CONFIG)
134
+ with init_empty_weights():
135
+ text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)
136
+
137
+ state_dict = load_split_weights(args.text_encoder1)
138
+
139
+ # support weights from ComfyUI
140
+ if "model.embed_tokens.weight" in state_dict:
141
+ for key in list(state_dict.keys()):
142
+ if key.startswith("model."):
143
+ new_key = key.replace("model.", "")
144
+ state_dict[new_key] = state_dict[key]
145
+ del state_dict[key]
146
+ if "tokenizer" in state_dict:
147
+ state_dict.pop("tokenizer")
148
+ if "lm_head.weight" in state_dict:
149
+ state_dict.pop("lm_head.weight")
150
+
151
+ # # support weights from ComfyUI
152
+ # if "tokenizer" in state_dict:
153
+ # state_dict.pop("tokenizer")
154
+
155
+ text_encoder1.load_state_dict(state_dict, strict=True, assign=True)
156
+
157
+ if fp8_llm:
158
+ org_dtype = text_encoder1.dtype
159
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
160
+ text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)
161
+
162
+ # prepare LLM for fp8
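+ # keep embeddings in the original dtype and run RMSNorm in float32; only the Linear weights remain in fp8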
163
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
164
+ def forward_hook(module):
165
+ def forward(hidden_states):
166
+ input_dtype = hidden_states.dtype
167
+ hidden_states = hidden_states.to(torch.float32)
168
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
169
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
170
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
171
+
172
+ return forward
173
+
174
+ for module in llama_model.modules():
175
+ if module.__class__.__name__ in ["Embedding"]:
176
+ # print("set", module.__class__.__name__, "to", target_dtype)
177
+ module.to(target_dtype)
178
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
179
+ # print("set", module.__class__.__name__, "hooks")
180
+ module.forward = forward_hook(module)
181
+
182
+ prepare_fp8(text_encoder1, org_dtype)
183
+ else:
184
+ text_encoder1.to(device)
185
+
186
+ text_encoder1.eval()
187
+ return tokenizer1, text_encoder1
188
+
189
+
190
+ def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
191
+ # single file and directory (contains 'text_encoder_2') support
192
+ logger.info(f"Loading text encoder 2 tokenizer")
193
+ tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")
194
+
195
+ logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
196
+ if os.path.isdir(args.text_encoder2):
197
+ # load from directory, configs are in the directory
198
+ text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)
199
+ else:
200
+ # we only have one file, so we can load it directly
201
+ config = CLIPConfig(**CLIP_CONFIG)
202
+ with init_empty_weights():
203
+ text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16)
204
+
205
+ state_dict = load_file(args.text_encoder2)
206
+
207
+ text_encoder2.load_state_dict(state_dict, strict=True, assign=True)
208
+
209
+ text_encoder2.eval()
210
+ return tokenizer2, text_encoder2
211
+
212
+
213
+ # endregion
214
+
215
+ # region image encoder
216
+
217
+ # Siglip configs are copied from FramePack repo
218
+ FEATURE_EXTRACTOR_CONFIG = {
219
+ "do_convert_rgb": None,
220
+ "do_normalize": True,
221
+ "do_rescale": True,
222
+ "do_resize": True,
223
+ "image_mean": [0.5, 0.5, 0.5],
224
+ "image_processor_type": "SiglipImageProcessor",
225
+ "image_std": [0.5, 0.5, 0.5],
226
+ "processor_class": "SiglipProcessor",
227
+ "resample": 3,
228
+ "rescale_factor": 0.00392156862745098,
229
+ "size": {"height": 384, "width": 384},
230
+ }
231
+ IMAGE_ENCODER_CONFIG = {
232
+ "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
233
+ "architectures": ["SiglipVisionModel"],
234
+ "attention_dropout": 0.0,
235
+ "hidden_act": "gelu_pytorch_tanh",
236
+ "hidden_size": 1152,
237
+ "image_size": 384,
238
+ "intermediate_size": 4304,
239
+ "layer_norm_eps": 1e-06,
240
+ "model_type": "siglip_vision_model",
241
+ "num_attention_heads": 16,
242
+ "num_channels": 3,
243
+ "num_hidden_layers": 27,
244
+ "patch_size": 14,
245
+ "torch_dtype": "bfloat16",
246
+ "transformers_version": "4.46.2",
247
+ }
248
+
249
+
250
+ def load_image_encoders(args):
251
+ logger.info(f"Loading image encoder feature extractor")
252
+ feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)
253
+
254
+ # single file, split file and directory (contains 'image_encoder') support
255
+ logger.info(f"Loading image encoder from {args.image_encoder}")
256
+ if os.path.isdir(args.image_encoder):
257
+ # load from directory, configs are in the directory
258
+ image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)
259
+ else:
260
+ # load from file, we create the model with the appropriate config
261
+ config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
262
+ with init_empty_weights():
263
+ image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)
264
+
265
+ state_dict = load_file(args.image_encoder)
266
+
267
+ image_encoder.load_state_dict(state_dict, strict=True, assign=True)
268
+
269
+ image_encoder.eval()
270
+ return feature_extractor, image_encoder
271
+
272
+
273
+ # endregion
frame_pack/hunyuan.py ADDED
@@ -0,0 +1,134 @@
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import torch
5
+
6
+ # from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
7
+ # from diffusers_helper.utils import crop_or_pad_yield_mask
8
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
9
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
10
+
11
+
12
+ @torch.no_grad()
13
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256, custom_system_prompt=None):
14
+ assert isinstance(prompt, str)
15
+
16
+ prompt = [prompt]
17
+
18
+ # LLAMA
19
+
20
+ # We can verify crop_start by checking the token count of the prompt:
21
+ # custom_system_prompt = (
22
+ # "Describe the video by detailing the following aspects: "
23
+ # "1. The main content and theme of the video."
24
+ # "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
25
+ # "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
26
+ # "4. background environment, light, style and atmosphere."
27
+ # "5. camera angles, movements, and transitions used in the video:"
28
+ # )
29
+ if custom_system_prompt is None:
30
+ prompt_llama = [PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format(p) for p in prompt]
31
+ crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"]["crop_start"]
32
+ else:
33
+ # count tokens for custom_system_prompt
34
+ full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{custom_system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
35
+ print(f"Custom system prompt: {full_prompt}")
36
+ system_prompt_tokens = tokenizer(full_prompt, return_tensors="pt", truncation=True).input_ids[0].shape[0]
37
+ print(f"Custom system prompt token count: {system_prompt_tokens}")
38
+ prompt_llama = [full_prompt + p + "<|eot_id|>" for p in prompt]
39
+ crop_start = system_prompt_tokens
40
+
41
+ llama_inputs = tokenizer(
42
+ prompt_llama,
43
+ padding="max_length",
44
+ max_length=max_length + crop_start,
45
+ truncation=True,
46
+ return_tensors="pt",
47
+ return_length=False,
48
+ return_overflowing_tokens=False,
49
+ return_attention_mask=True,
50
+ )
51
+
52
+ llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
53
+ llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
54
+ llama_attention_length = int(llama_attention_mask.sum())
55
+
56
+ llama_outputs = text_encoder(
57
+ input_ids=llama_input_ids,
58
+ attention_mask=llama_attention_mask,
59
+ output_hidden_states=True,
60
+ )
61
+
62
+ llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
63
+ # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
64
+ llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
65
+
66
+ assert torch.all(llama_attention_mask.bool())
67
+
68
+ # CLIP
69
+
70
+ clip_l_input_ids = tokenizer_2(
71
+ prompt,
72
+ padding="max_length",
73
+ max_length=77,
74
+ truncation=True,
75
+ return_overflowing_tokens=False,
76
+ return_length=False,
77
+ return_tensors="pt",
78
+ ).input_ids
79
+ clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
80
+
81
+ return llama_vec, clip_l_pooler
82
+
83
+
84
+ @torch.no_grad()
85
+ def vae_decode_fake(latents):
86
+ latent_rgb_factors = [
87
+ [-0.0395, -0.0331, 0.0445],
88
+ [0.0696, 0.0795, 0.0518],
89
+ [0.0135, -0.0945, -0.0282],
90
+ [0.0108, -0.0250, -0.0765],
91
+ [-0.0209, 0.0032, 0.0224],
92
+ [-0.0804, -0.0254, -0.0639],
93
+ [-0.0991, 0.0271, -0.0669],
94
+ [-0.0646, -0.0422, -0.0400],
95
+ [-0.0696, -0.0595, -0.0894],
96
+ [-0.0799, -0.0208, -0.0375],
97
+ [0.1166, 0.1627, 0.0962],
98
+ [0.1165, 0.0432, 0.0407],
99
+ [-0.2315, -0.1920, -0.1355],
100
+ [-0.0270, 0.0401, -0.0821],
101
+ [-0.0616, -0.0997, -0.0727],
102
+ [0.0249, -0.0469, -0.1703],
103
+ ] # From comfyui
104
+
105
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
106
+
107
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
108
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
109
+
110
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
111
+ images = images.clamp(0.0, 1.0)
112
+
113
+ return images
114
+
115
+
116
+ @torch.no_grad()
117
+ def vae_decode(latents, vae, image_mode=False) -> torch.Tensor:
118
+ latents = latents / vae.config.scaling_factor
119
+
120
+ if not image_mode:
121
+ image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
122
+ else:
123
+ latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
124
+ image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
125
+ image = torch.cat(image, dim=2)
126
+
127
+ return image
128
+
129
+
130
+ @torch.no_grad()
131
+ def vae_encode(image, vae: AutoencoderKLCausal3D) -> torch.Tensor:
132
+ latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
133
+ latents = latents * vae.config.scaling_factor
134
+ return latents
frame_pack/hunyuan_video_packed.py ADDED
@@ -0,0 +1,2038 @@
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import glob
5
+ import math
6
+ import numbers
7
+ import os
8
+ from types import SimpleNamespace
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import einops
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ from modules.custom_offloading_utils import ModelOffloader
18
+ from utils.safetensors_utils import load_split_weights
19
+ from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
20
+ from accelerate import init_empty_weights
21
+
22
+ try:
23
+ # raise NotImplementedError
24
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
25
+
26
+ print("Xformers is installed!")
27
+ except:
28
+ print("Xformers is not installed!")
29
+ xformers_attn_func = None
30
+
31
+ try:
32
+ # raise NotImplementedError
33
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
34
+
35
+ print("Flash Attn is installed!")
36
+ except:
37
+ print("Flash Attn is not installed!")
38
+ flash_attn_varlen_func = None
39
+ flash_attn_func = None
40
+
41
+ try:
42
+ # raise NotImplementedError
43
+ from sageattention import sageattn_varlen, sageattn
44
+
45
+ print("Sage Attn is installed!")
46
+ except:
47
+ print("Sage Attn is not installed!")
48
+ sageattn_varlen = None
49
+ sageattn = None
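+ # optional attention backends: a failed import leaves the corresponding function as None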
50
+
51
+
52
+ import logging
53
+
54
+ logger = logging.getLogger(__name__)
55
+ logging.basicConfig(level=logging.INFO)
56
+
57
+ # region diffusers
58
+
59
+ # copied from diffusers with some modifications to minimize dependencies
60
+ # original code: https://github.com/huggingface/diffusers/
61
+ # original license: Apache-2.0
62
+
63
+ ACT2CLS = {
64
+ "swish": nn.SiLU,
65
+ "silu": nn.SiLU,
66
+ "mish": nn.Mish,
67
+ "gelu": nn.GELU,
68
+ "relu": nn.ReLU,
69
+ }
70
+
71
+
72
+ def get_activation(act_fn: str) -> nn.Module:
73
+ """Helper function to get activation function from string.
74
+
75
+ Args:
76
+ act_fn (str): Name of activation function.
77
+
78
+ Returns:
79
+ nn.Module: Activation function.
80
+ """
81
+
82
+ act_fn = act_fn.lower()
83
+ if act_fn in ACT2CLS:
84
+ return ACT2CLS[act_fn]()
85
+ else:
86
+ raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
87
+
88
+
89
+ def get_timestep_embedding(
90
+ timesteps: torch.Tensor,
91
+ embedding_dim: int,
92
+ flip_sin_to_cos: bool = False,
93
+ downscale_freq_shift: float = 1,
94
+ scale: float = 1,
95
+ max_period: int = 10000,
96
+ ):
97
+ """
98
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
99
+
100
+ Args
101
+ timesteps (torch.Tensor):
102
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
103
+ embedding_dim (int):
104
+ the dimension of the output.
105
+ flip_sin_to_cos (bool):
106
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
107
+ downscale_freq_shift (float):
108
+ Controls the delta between frequencies between dimensions
109
+ scale (float):
110
+ Scaling factor applied to the embeddings.
111
+ max_period (int):
112
+ Controls the maximum frequency of the embeddings
113
+ Returns
114
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
115
+ """
116
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
117
+
118
+ half_dim = embedding_dim // 2
119
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
120
+ exponent = exponent / (half_dim - downscale_freq_shift)
121
+
122
+ emb = torch.exp(exponent)
123
+ emb = timesteps[:, None].float() * emb[None, :]
124
+
125
+ # scale embeddings
126
+ emb = scale * emb
127
+
128
+ # concat sine and cosine embeddings
129
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
130
+
131
+ # flip sine and cosine embeddings
132
+ if flip_sin_to_cos:
133
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
134
+
135
+ # zero pad
136
+ if embedding_dim % 2 == 1:
137
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
138
+ return emb
139
+
140
+
141
+ class TimestepEmbedding(nn.Module):
142
+ def __init__(
143
+ self,
144
+ in_channels: int,
145
+ time_embed_dim: int,
146
+ act_fn: str = "silu",
147
+ out_dim: int = None,
148
+ post_act_fn: Optional[str] = None,
149
+ cond_proj_dim=None,
150
+ sample_proj_bias=True,
151
+ ):
152
+ super().__init__()
153
+
154
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
155
+
156
+ if cond_proj_dim is not None:
157
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
158
+ else:
159
+ self.cond_proj = None
160
+
161
+ self.act = get_activation(act_fn)
162
+
163
+ if out_dim is not None:
164
+ time_embed_dim_out = out_dim
165
+ else:
166
+ time_embed_dim_out = time_embed_dim
167
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
168
+
169
+ if post_act_fn is None:
170
+ self.post_act = None
171
+ else:
172
+ self.post_act = get_activation(post_act_fn)
173
+
174
+ def forward(self, sample, condition=None):
175
+ if condition is not None:
176
+ sample = sample + self.cond_proj(condition)
177
+ sample = self.linear_1(sample)
178
+
179
+ if self.act is not None:
180
+ sample = self.act(sample)
181
+
182
+ sample = self.linear_2(sample)
183
+
184
+ if self.post_act is not None:
185
+ sample = self.post_act(sample)
186
+ return sample
187
+
188
+
189
+ class Timesteps(nn.Module):
190
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
191
+ super().__init__()
192
+ self.num_channels = num_channels
193
+ self.flip_sin_to_cos = flip_sin_to_cos
194
+ self.downscale_freq_shift = downscale_freq_shift
195
+ self.scale = scale
196
+
197
+ def forward(self, timesteps):
198
+ t_emb = get_timestep_embedding(
199
+ timesteps,
200
+ self.num_channels,
201
+ flip_sin_to_cos=self.flip_sin_to_cos,
202
+ downscale_freq_shift=self.downscale_freq_shift,
203
+ scale=self.scale,
204
+ )
205
+ return t_emb
206
+
207
+
208
+ class FP32SiLU(nn.Module):
209
+ r"""
210
+ SiLU activation function with input upcasted to torch.float32.
211
+ """
212
+
213
+ def __init__(self):
214
+ super().__init__()
215
+
216
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
217
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
218
+
219
+
220
+ class GELU(nn.Module):
221
+ r"""
222
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
223
+
224
+ Parameters:
225
+ dim_in (`int`): The number of channels in the input.
226
+ dim_out (`int`): The number of channels in the output.
227
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
228
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
229
+ """
230
+
231
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
232
+ super().__init__()
233
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
234
+ self.approximate = approximate
235
+
236
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
237
+ # if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
238
+ # # fp16 gelu not supported on mps before torch 2.0
239
+ # return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
240
+ return F.gelu(gate, approximate=self.approximate)
241
+
242
+ def forward(self, hidden_states):
243
+ hidden_states = self.proj(hidden_states)
244
+ hidden_states = self.gelu(hidden_states)
245
+ return hidden_states
246
+
247
+
248
+ class PixArtAlphaTextProjection(nn.Module):
249
+ """
250
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
251
+
252
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
253
+ """
254
+
255
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
256
+ super().__init__()
257
+ if out_features is None:
258
+ out_features = hidden_size
259
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
260
+ if act_fn == "gelu_tanh":
261
+ self.act_1 = nn.GELU(approximate="tanh")
262
+ elif act_fn == "silu":
263
+ self.act_1 = nn.SiLU()
264
+ elif act_fn == "silu_fp32":
265
+ self.act_1 = FP32SiLU()
266
+ else:
267
+ raise ValueError(f"Unknown activation function: {act_fn}")
268
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
269
+
270
+ def forward(self, caption):
271
+ hidden_states = self.linear_1(caption)
272
+ hidden_states = self.act_1(hidden_states)
273
+ hidden_states = self.linear_2(hidden_states)
274
+ return hidden_states
275
+
276
+
277
+ class LayerNormFramePack(nn.LayerNorm):
278
+ # casting to dtype of input tensor is added
279
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
280
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
281
+
282
+
283
+ class FP32LayerNormFramePack(nn.LayerNorm):
284
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
285
+ origin_dtype = x.dtype
286
+ return torch.nn.functional.layer_norm(
287
+ x.float(),
288
+ self.normalized_shape,
289
+ self.weight.float() if self.weight is not None else None,
290
+ self.bias.float() if self.bias is not None else None,
291
+ self.eps,
292
+ ).to(origin_dtype)
293
+
294
+
295
+ class RMSNormFramePack(nn.Module):
296
+ r"""
297
+ RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
298
+
299
+ Args:
300
+ dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
301
+ eps (`float`): Small value to use when calculating the reciprocal of the square-root.
302
+ elementwise_affine (`bool`, defaults to `True`):
303
+ Boolean flag to denote if affine transformation should be applied.
304
+ bias (`bool`, defaults to False): If also training the `bias` param.
305
+ """
306
+
307
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
308
+ super().__init__()
309
+
310
+ self.eps = eps
311
+ self.elementwise_affine = elementwise_affine
312
+
313
+ if isinstance(dim, numbers.Integral):
314
+ dim = (dim,)
315
+
316
+ self.dim = torch.Size(dim)
317
+
318
+ self.weight = None
319
+ self.bias = None
320
+
321
+ if elementwise_affine:
322
+ self.weight = nn.Parameter(torch.ones(dim))
323
+ if bias:
324
+ self.bias = nn.Parameter(torch.zeros(dim))
325
+
326
+ def forward(self, hidden_states):
327
+ input_dtype = hidden_states.dtype
328
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
329
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
330
+
331
+ if self.weight is None:
332
+ return hidden_states.to(input_dtype)
333
+
334
+ return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
335
+
336
+
337
+ class AdaLayerNormContinuousFramePack(nn.Module):
338
+ r"""
339
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
340
+
341
+ Args:
342
+ embedding_dim (`int`): Embedding dimension to use during projection.
343
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
344
+ elementwise_affine (`bool`, defaults to `True`):
345
+ Boolean flag to denote if affine transformation should be applied.
346
+ eps (`float`, defaults to 1e-5): Epsilon factor.
347
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
348
+ norm_type (`str`, defaults to `"layer_norm"`):
349
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
350
+ """
351
+
352
+ def __init__(
353
+ self,
354
+ embedding_dim: int,
355
+ conditioning_embedding_dim: int,
356
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
357
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
358
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
359
+ # However, this is how it was implemented in the original code, and it's rather likely you should
360
+ # set `elementwise_affine` to False.
361
+ elementwise_affine=True,
362
+ eps=1e-5,
363
+ bias=True,
364
+ norm_type="layer_norm",
365
+ ):
366
+ super().__init__()
367
+ self.silu = nn.SiLU()
368
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
369
+ if norm_type == "layer_norm":
370
+ self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
371
+ elif norm_type == "rms_norm":
372
+ self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine)
373
+ else:
374
+ raise ValueError(f"unknown norm_type {norm_type}")
375
+
376
+ def forward(self, x, conditioning_embedding):
377
+ emb = self.linear(self.silu(conditioning_embedding))
378
+ scale, shift = emb.chunk(2, dim=1)
379
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
380
+ return x
381
+
382
+
383
+ class LinearActivation(nn.Module):
384
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
385
+ super().__init__()
386
+
387
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
388
+ self.activation = get_activation(activation)
389
+
390
+ def forward(self, hidden_states):
391
+ hidden_states = self.proj(hidden_states)
392
+ return self.activation(hidden_states)
393
+
394
+
395
+ class FeedForward(nn.Module):
396
+ r"""
397
+ A feed-forward layer.
398
+
399
+ Parameters:
400
+ dim (`int`): The number of channels in the input.
401
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
402
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
403
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
404
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
405
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
406
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ dim: int,
412
+ dim_out: Optional[int] = None,
413
+ mult: int = 4,
414
+ dropout: float = 0.0,
415
+ activation_fn: str = "geglu",
416
+ final_dropout: bool = False,
417
+ inner_dim=None,
418
+ bias: bool = True,
419
+ ):
420
+ super().__init__()
421
+ if inner_dim is None:
422
+ inner_dim = int(dim * mult)
423
+ dim_out = dim_out if dim_out is not None else dim
424
+
425
+ # if activation_fn == "gelu":
426
+ # act_fn = GELU(dim, inner_dim, bias=bias)
427
+ if activation_fn == "gelu-approximate":
428
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
429
+ # elif activation_fn == "geglu":
430
+ # act_fn = GEGLU(dim, inner_dim, bias=bias)
431
+ # elif activation_fn == "geglu-approximate":
432
+ # act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
433
+ # elif activation_fn == "swiglu":
434
+ # act_fn = SwiGLU(dim, inner_dim, bias=bias)
435
+ elif activation_fn == "linear-silu":
436
+ act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
437
+ else:
438
+ raise ValueError(f"Unknown activation function: {activation_fn}")
439
+
440
+ self.net = nn.ModuleList([])
441
+ # project in
442
+ self.net.append(act_fn)
443
+ # project dropout
444
+ self.net.append(nn.Dropout(dropout))
445
+ # project out
446
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
447
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
448
+ if final_dropout:
449
+ self.net.append(nn.Dropout(dropout))
450
+
451
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
452
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
453
+ # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
454
+ # deprecate("scale", "1.0.0", deprecation_message)
455
+ raise ValueError("scale is not supported in this version. Please remove it.")
456
+ for module in self.net:
457
+ hidden_states = module(hidden_states)
458
+ return hidden_states
459
+
460
+
461
+ # @maybe_allow_in_graph
462
+ class Attention(nn.Module):
463
+ r"""
464
+ Minimal copy of Attention class from diffusers.
465
+ """
466
+
467
+ def __init__(
468
+ self,
469
+ query_dim: int,
470
+ cross_attention_dim: Optional[int] = None,
471
+ heads: int = 8,
472
+ dim_head: int = 64,
473
+ bias: bool = False,
474
+ qk_norm: Optional[str] = None,
475
+ added_kv_proj_dim: Optional[int] = None,
476
+ eps: float = 1e-5,
477
+ processor: Optional[any] = None,
478
+ out_dim: Optional[int] = None,
479
+ context_pre_only=None,
480
+ pre_only=False,
481
+ ):
482
+ super().__init__()
483
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
484
+ self.inner_kv_dim = self.inner_dim # if kv_heads is None else dim_head * kv_heads
485
+ self.query_dim = query_dim
486
+ self.use_bias = bias
487
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
488
+ self.out_dim = out_dim if out_dim is not None else query_dim
489
+ self.out_context_dim = query_dim
490
+ self.context_pre_only = context_pre_only
491
+ self.pre_only = pre_only
492
+
493
+ self.scale = dim_head**-0.5
494
+ self.heads = out_dim // dim_head if out_dim is not None else heads
495
+
496
+ self.added_kv_proj_dim = added_kv_proj_dim
497
+
498
+ if qk_norm is None:
499
+ self.norm_q = None
500
+ self.norm_k = None
501
+ elif qk_norm == "rms_norm":
502
+ self.norm_q = RMSNormFramePack(dim_head, eps=eps)
503
+ self.norm_k = RMSNormFramePack(dim_head, eps=eps)
504
+ else:
505
+ raise ValueError(
506
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
507
+ )
508
+
509
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
510
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
511
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
512
+
513
+ self.added_proj_bias = True # added_proj_bias
514
+ if self.added_kv_proj_dim is not None:
515
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
516
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
517
+ if self.context_pre_only is not None:
518
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
519
+ else:
520
+ self.add_q_proj = None
521
+ self.add_k_proj = None
522
+ self.add_v_proj = None
523
+
524
+ if not self.pre_only:
525
+ self.to_out = nn.ModuleList([])
526
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True))
527
+ # self.to_out.append(nn.Dropout(dropout))
528
+ self.to_out.append(nn.Identity()) # dropout=0.0
529
+ else:
530
+ self.to_out = None
531
+
532
+ if self.context_pre_only is not None and not self.context_pre_only:
533
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True)
534
+ else:
535
+ self.to_add_out = None
536
+
537
+ if qk_norm is not None and added_kv_proj_dim is not None:
538
+ if qk_norm == "rms_norm":
539
+ self.norm_added_q = RMSNormFramePack(dim_head, eps=eps)
540
+ self.norm_added_k = RMSNormFramePack(dim_head, eps=eps)
541
+ else:
542
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`")
543
+ else:
544
+ self.norm_added_q = None
545
+ self.norm_added_k = None
546
+
547
+ # set attention processor
548
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
549
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
550
+ if processor is None:
551
+ processor = AttnProcessor2_0()
552
+ self.set_processor(processor)
553
+
554
+ def set_processor(self, processor: any) -> None:
555
+ self.processor = processor
556
+
557
+ def get_processor(self) -> any:
558
+ return self.processor
559
+
560
+ def forward(
561
+ self,
562
+ hidden_states: torch.Tensor,
563
+ encoder_hidden_states: Optional[torch.Tensor] = None,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ **cross_attention_kwargs,
566
+ ) -> torch.Tensor:
567
+ return self.processor(
568
+ self,
569
+ hidden_states,
570
+ encoder_hidden_states=encoder_hidden_states,
571
+ attention_mask=attention_mask,
572
+ **cross_attention_kwargs,
573
+ )
574
+
575
+ def prepare_attention_mask(
576
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
577
+ ) -> torch.Tensor:
578
+ r"""
579
+ Prepare the attention mask for the attention computation.
580
+
581
+ Args:
582
+ attention_mask (`torch.Tensor`):
583
+ The attention mask to prepare.
584
+ target_length (`int`):
585
+ The target length of the attention mask. This is the length of the attention mask after padding.
586
+ batch_size (`int`):
587
+ The batch size, which is used to repeat the attention mask.
588
+ out_dim (`int`, *optional*, defaults to `3`):
589
+ The output dimension of the attention mask. Can be either `3` or `4`.
590
+
591
+ Returns:
592
+ `torch.Tensor`: The prepared attention mask.
593
+ """
594
+ head_size = self.heads
595
+ if attention_mask is None:
596
+ return attention_mask
597
+
598
+ current_length: int = attention_mask.shape[-1]
599
+ if current_length != target_length:
600
+ if attention_mask.device.type == "mps":
601
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
602
+ # Instead, we can manually construct the padding tensor.
603
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
604
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
605
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
606
+ else:
607
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
608
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
609
+ # remaining_length: int = target_length - current_length
610
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
611
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
612
+
613
+ if out_dim == 3:
614
+ if attention_mask.shape[0] < batch_size * head_size:
615
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size)
616
+ elif out_dim == 4:
617
+ attention_mask = attention_mask.unsqueeze(1)
618
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size)
619
+
620
+ return attention_mask
621
+
622
+
623
+ class AttnProcessor2_0:
624
+ r"""
625
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
626
+ """
627
+
628
+ def __init__(self):
629
+ if not hasattr(F, "scaled_dot_product_attention"):
630
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
631
+
632
+ def __call__(
633
+ self,
634
+ attn: Attention,
635
+ hidden_states: torch.Tensor,
636
+ encoder_hidden_states: Optional[torch.Tensor] = None,
637
+ attention_mask: Optional[torch.Tensor] = None,
638
+ temb: Optional[torch.Tensor] = None,
639
+ *args,
640
+ **kwargs,
641
+ ) -> torch.Tensor:
642
+ input_ndim = hidden_states.ndim
643
+
644
+ if input_ndim == 4:
645
+ batch_size, channel, height, width = hidden_states.shape
646
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
647
+
648
+ batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
649
+
650
+ if attention_mask is not None:
651
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
652
+ # scaled_dot_product_attention expects attention_mask shape to be
653
+ # (batch, heads, source_length, target_length)
654
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
655
+
656
+ query = attn.to_q(hidden_states)
657
+ query_dtype = query.dtype # store dtype before potentially deleting query
658
+
659
+ if encoder_hidden_states is None:
660
+ encoder_hidden_states = hidden_states
661
+
662
+ key = attn.to_k(encoder_hidden_states)
663
+ value = attn.to_v(encoder_hidden_states)
664
+
665
+ inner_dim = key.shape[-1]
666
+ head_dim = inner_dim // attn.heads
667
+
668
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
669
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
670
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
671
+
672
+ if attn.norm_q is not None:
673
+ query = attn.norm_q(query)
674
+ if attn.norm_k is not None:
675
+ key = attn.norm_k(key)
676
+
677
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
678
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
679
+ del query, key, value, attention_mask # free memory
680
+
681
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
682
+ hidden_states = hidden_states.to(query_dtype) # use stored dtype
683
+
684
+ # linear proj
685
+ hidden_states = attn.to_out[0](hidden_states)
686
+ # dropout
687
+ hidden_states = attn.to_out[1](hidden_states)
688
+
689
+ if input_ndim == 4:
690
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
691
+
692
+ return hidden_states
693
+
694
+
695
+ # endregion diffusers
696
+
697
+
698
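+ # Pads t/h/w up to the next multiple of the 3D conv kernel size using replicate padding
+ # (F.pad takes pad widths from the last dimension backwards: w, then h, then t).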
+ def pad_for_3d_conv(x, kernel_size):
699
+ b, c, t, h, w = x.shape
700
+ pt, ph, pw = kernel_size
701
+ pad_t = (pt - (t % pt)) % pt
702
+ pad_h = (ph - (h % ph)) % ph
703
+ pad_w = (pw - (w % pw)) % pw
704
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
705
+
706
+
707
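+ # Downsamples by the kernel window; the original centre-pixel selection (left commented out below)
+ # was replaced by average pooling.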
+ def center_down_sample_3d(x, kernel_size):
708
+ # pt, ph, pw = kernel_size
709
+ # cp = (pt * ph * pw) // 2
710
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
711
+ # xc = xp[cp]
712
+ # return xc
713
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
714
+
715
+
716
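+ # Builds cumulative sequence boundaries for the varlen attention kernels over the flattened
+ # (batch x (image + padded text)) token stream: for each sample, one boundary at the end of its
+ # valid image+text tokens and one at the end of its padded row.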
+ def get_cu_seqlens(text_mask, img_len):
717
+ batch_size = text_mask.shape[0]
718
+ text_len = text_mask.sum(dim=1)
719
+ max_len = text_mask.shape[1] + img_len
720
+
721
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device) # ensure device match
722
+
723
+ for i in range(batch_size):
724
+ s = text_len[i] + img_len
725
+ s1 = i * max_len + s
726
+ s2 = (i + 1) * max_len
727
+ cu_seqlens[2 * i + 1] = s1
728
+ cu_seqlens[2 * i + 2] = s2
729
+
730
+ return cu_seqlens
731
+
732
+
733
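+ # Rotary embedding for (B, L, heads, dim)-layout tensors: freqs_cis stores the cos and sin halves
+ # concatenated along the last dimension, and adjacent channel pairs of x are rotated.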
+ def apply_rotary_emb_transposed(x, freqs_cis):
734
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
735
+ del freqs_cis
736
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
737
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
738
+ del x_real, x_imag
739
+ return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
740
+
741
+
742
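+ # Attention dispatcher: with no cu_seqlens it runs plain batched attention (SageAttention, FlashAttention,
+ # xformers or PyTorch SDPA, in that order of preference); with split_attn it runs the chosen backend one
+ # sample at a time; otherwise it falls back to the varlen kernels, which require SageAttention or FlashAttention.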
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False):
743
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
744
+ if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
745
+ x = sageattn(q, k, v, tensor_layout="NHD")
746
+ return x
747
+
748
+ if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
749
+ x = flash_attn_func(q, k, v)
750
+ return x
751
+
752
+ if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
753
+ x = xformers_attn_func(q, k, v)
754
+ return x
755
+
756
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(
757
+ 1, 2
758
+ )
759
+ return x
760
+ if split_attn:
761
+ if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
762
+ x = torch.empty_like(q)
763
+ for i in range(q.size(0)):
764
+ x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD")
765
+ return x
766
+
767
+ if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
768
+ x = torch.empty_like(q)
769
+ for i in range(q.size(0)):
770
+ x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
771
+ return x
772
+
773
+ if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
774
+ x = torch.empty_like(q)
775
+ for i in range(q.size(0)):
776
+ x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
777
+ return x
778
+
779
+ q = q.transpose(1, 2)
780
+ k = k.transpose(1, 2)
781
+ v = v.transpose(1, 2)
782
+ x = torch.empty_like(q)
783
+ for i in range(q.size(0)):
784
+ x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1])
785
+ x = x.transpose(1, 2)
786
+ return x
787
+
788
+ batch_size = q.shape[0]
789
+ q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
790
+ k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
791
+ v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
792
+ if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None:
793
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
794
+ del q, k, v # free memory
795
+ elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None:
796
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
797
+ del q, k, v # free memory
798
+ else:
799
+ raise NotImplementedError("No varlen attention backend (SageAttention/FlashAttention) is installed, so batch_size > 1 is not supported in this configuration. Try `--split_attn`.")
800
+ x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
801
+ return x
802
+
803
+
804
+ class HunyuanAttnProcessorFlashAttnDouble:
805
+ def __call__(
806
+ self,
807
+ attn: Attention,
808
+ hidden_states,
809
+ encoder_hidden_states,
810
+ attention_mask,
811
+ image_rotary_emb,
812
+ attn_mode: Optional[str] = None,
813
+ split_attn: Optional[bool] = False,
814
+ ):
815
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
816
+
817
+ # Project image latents
818
+ query = attn.to_q(hidden_states)
819
+ key = attn.to_k(hidden_states)
820
+ value = attn.to_v(hidden_states)
821
+ del hidden_states # free memory
822
+
823
+ query = query.unflatten(2, (attn.heads, -1))
824
+ key = key.unflatten(2, (attn.heads, -1))
825
+ value = value.unflatten(2, (attn.heads, -1))
826
+
827
+ query = attn.norm_q(query)
828
+ key = attn.norm_k(key)
829
+
830
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
831
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
832
+ del image_rotary_emb # free memory
833
+
834
+ # Project context (text/encoder) embeddings
835
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
836
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
837
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
838
+ txt_length = encoder_hidden_states.shape[1] # store length before deleting
839
+ del encoder_hidden_states # free memory
840
+
841
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
842
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
843
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
844
+
845
+ encoder_query = attn.norm_added_q(encoder_query)
846
+ encoder_key = attn.norm_added_k(encoder_key)
847
+
848
+ # Concatenate image and context q, k, v
849
+ query = torch.cat([query, encoder_query], dim=1)
850
+ key = torch.cat([key, encoder_key], dim=1)
851
+ value = torch.cat([value, encoder_value], dim=1)
852
+ del encoder_query, encoder_key, encoder_value # free memory
853
+
854
+ hidden_states_attn = attn_varlen_func(
855
+ query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
856
+ )
857
+ del query, key, value # free memory
858
+ hidden_states_attn = hidden_states_attn.flatten(-2)
859
+
860
+ hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:]
861
+ del hidden_states_attn # free memory
862
+
863
+ # Apply output projections
864
+ hidden_states = attn.to_out[0](hidden_states)
865
+ hidden_states = attn.to_out[1](hidden_states) # Dropout/Identity
866
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
867
+
868
+ return hidden_states, encoder_hidden_states
869
+
870
+
871
+ class HunyuanAttnProcessorFlashAttnSingle:
872
+ def __call__(
873
+ self,
874
+ attn: Attention,
875
+ hidden_states,
876
+ encoder_hidden_states,
877
+ attention_mask,
878
+ image_rotary_emb,
879
+ attn_mode: Optional[str] = None,
880
+ split_attn: Optional[bool] = False,
881
+ ):
882
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
883
+ txt_length = encoder_hidden_states.shape[1] # Store text length
884
+
885
+ # Concatenate image and context inputs
886
+ hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1)
887
+ del hidden_states, encoder_hidden_states # free memory
888
+
889
+ # Project concatenated inputs
890
+ query = attn.to_q(hidden_states_cat)
891
+ key = attn.to_k(hidden_states_cat)
892
+ value = attn.to_v(hidden_states_cat)
893
+ del hidden_states_cat # free memory
894
+
895
+ query = query.unflatten(2, (attn.heads, -1))
896
+ key = key.unflatten(2, (attn.heads, -1))
897
+ value = value.unflatten(2, (attn.heads, -1))
898
+
899
+ query = attn.norm_q(query)
900
+ key = attn.norm_k(key)
901
+
902
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
903
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
904
+ del image_rotary_emb # free memory
905
+
906
+ hidden_states = attn_varlen_func(
907
+ query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
908
+ )
909
+ del query, key, value # free memory
910
+ hidden_states = hidden_states.flatten(-2)
911
+
912
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
913
+
914
+ return hidden_states, encoder_hidden_states
915
+
916
+
917
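+ # Sums the sinusoidal timestep embedding, the guidance embedding and the projected pooled text
+ # embedding into a single conditioning vector.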
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
918
+ def __init__(self, embedding_dim, pooled_projection_dim):
919
+ super().__init__()
920
+
921
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
922
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
923
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
924
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
925
+
926
+ def forward(self, timestep, guidance, pooled_projection):
927
+ timesteps_proj = self.time_proj(timestep)
928
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
929
+
930
+ guidance_proj = self.time_proj(guidance)
931
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
932
+
933
+ time_guidance_emb = timesteps_emb + guidance_emb
934
+
935
+ pooled_projections = self.text_embedder(pooled_projection)
936
+ conditioning = time_guidance_emb + pooled_projections
937
+
938
+ return conditioning
939
+
940
+
941
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
942
+ def __init__(self, embedding_dim, pooled_projection_dim):
943
+ super().__init__()
944
+
945
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
946
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
947
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
948
+
949
+ def forward(self, timestep, pooled_projection):
950
+ timesteps_proj = self.time_proj(timestep)
951
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
952
+
953
+ pooled_projections = self.text_embedder(pooled_projection)
954
+
955
+ conditioning = timesteps_emb + pooled_projections
956
+
957
+ return conditioning
958
+
959
+
960
+ class HunyuanVideoAdaNorm(nn.Module):
961
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
962
+ super().__init__()
963
+
964
+ out_features = out_features or 2 * in_features
965
+ self.linear = nn.Linear(in_features, out_features)
966
+ self.nonlinearity = nn.SiLU()
967
+
968
+ def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
969
+ temb = self.linear(self.nonlinearity(temb))
970
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
971
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
972
+ return gate_msa, gate_mlp
973
+
974
+
975
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
976
+ def __init__(
977
+ self,
978
+ num_attention_heads: int,
979
+ attention_head_dim: int,
980
+ mlp_width_ratio: float = 4.0,
981
+ mlp_drop_rate: float = 0.0,
982
+ attention_bias: bool = True,
983
+ ) -> None:
984
+ super().__init__()
985
+
986
+ hidden_size = num_attention_heads * attention_head_dim
987
+
988
+ self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
989
+ self.attn = Attention(
990
+ query_dim=hidden_size,
991
+ cross_attention_dim=None,
992
+ heads=num_attention_heads,
993
+ dim_head=attention_head_dim,
994
+ bias=attention_bias,
995
+ )
996
+
997
+ self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
998
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
999
+
1000
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
1001
+
1002
+ def forward(
1003
+ self,
1004
+ hidden_states: torch.Tensor,
1005
+ temb: torch.Tensor,
1006
+ attention_mask: Optional[torch.Tensor] = None,
1007
+ ) -> torch.Tensor:
1008
+ norm_hidden_states = self.norm1(hidden_states)
1009
+
1010
+ # Self-attention
1011
+ attn_output = self.attn(
1012
+ hidden_states=norm_hidden_states,
1013
+ encoder_hidden_states=None,
1014
+ attention_mask=attention_mask,
1015
+ )
1016
+ del norm_hidden_states # free memory
1017
+
1018
+ gate_msa, gate_mlp = self.norm_out(temb)
1019
+ hidden_states = hidden_states + attn_output * gate_msa
1020
+ del attn_output, gate_msa # free memory
1021
+
1022
+ ff_output = self.ff(self.norm2(hidden_states))
1023
+ hidden_states = hidden_states + ff_output * gate_mlp
1024
+ del ff_output, gate_mlp # free memory
1025
+
1026
+ return hidden_states
1027
+
1028
+
1029
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
1030
+ def __init__(
1031
+ self,
1032
+ num_attention_heads: int,
1033
+ attention_head_dim: int,
1034
+ num_layers: int,
1035
+ mlp_width_ratio: float = 4.0,
1036
+ mlp_drop_rate: float = 0.0,
1037
+ attention_bias: bool = True,
1038
+ ) -> None:
1039
+ super().__init__()
1040
+
1041
+ self.refiner_blocks = nn.ModuleList(
1042
+ [
1043
+ HunyuanVideoIndividualTokenRefinerBlock(
1044
+ num_attention_heads=num_attention_heads,
1045
+ attention_head_dim=attention_head_dim,
1046
+ mlp_width_ratio=mlp_width_ratio,
1047
+ mlp_drop_rate=mlp_drop_rate,
1048
+ attention_bias=attention_bias,
1049
+ )
1050
+ for _ in range(num_layers)
1051
+ ]
1052
+ )
1053
+
1054
+ def forward(
1055
+ self,
1056
+ hidden_states: torch.Tensor,
1057
+ temb: torch.Tensor,
1058
+ attention_mask: Optional[torch.Tensor] = None,
1059
+ ) -> torch.Tensor:
1060
+ self_attn_mask = None
1061
+ if attention_mask is not None:
1062
+ batch_size = attention_mask.shape[0]
1063
+ seq_len = attention_mask.shape[1]
1064
+ attention_mask = attention_mask.to(hidden_states.device).bool()
1065
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
1066
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
1067
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
1068
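+ # Ensure every query can attend to at least the first token, so rows that would otherwise be
+ # fully masked do not produce NaNs in the softmax.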
+ self_attn_mask[:, :, :, 0] = True
1069
+
1070
+ for block in self.refiner_blocks:
1071
+ hidden_states = block(hidden_states, temb, self_attn_mask)
1072
+
1073
+ return hidden_states
1074
+
1075
+
1076
+ class HunyuanVideoTokenRefiner(nn.Module):
1077
+ def __init__(
1078
+ self,
1079
+ in_channels: int,
1080
+ num_attention_heads: int,
1081
+ attention_head_dim: int,
1082
+ num_layers: int,
1083
+ mlp_ratio: float = 4.0,
1084
+ mlp_drop_rate: float = 0.0,
1085
+ attention_bias: bool = True,
1086
+ ) -> None:
1087
+ super().__init__()
1088
+
1089
+ hidden_size = num_attention_heads * attention_head_dim
1090
+
1091
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels)
1092
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
1093
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
1094
+ num_attention_heads=num_attention_heads,
1095
+ attention_head_dim=attention_head_dim,
1096
+ num_layers=num_layers,
1097
+ mlp_width_ratio=mlp_ratio,
1098
+ mlp_drop_rate=mlp_drop_rate,
1099
+ attention_bias=attention_bias,
1100
+ )
1101
+
1102
+ def forward(
1103
+ self,
1104
+ hidden_states: torch.Tensor,
1105
+ timestep: torch.LongTensor,
1106
+ attention_mask: Optional[torch.LongTensor] = None,
1107
+ ) -> torch.Tensor:
1108
+ if attention_mask is None:
1109
+ pooled_projections = hidden_states.mean(dim=1)
1110
+ else:
1111
+ original_dtype = hidden_states.dtype
1112
+ mask_float = attention_mask.float().unsqueeze(-1)
1113
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
1114
+ pooled_projections = pooled_projections.to(original_dtype)
1115
+
1116
+ temb = self.time_text_embed(timestep, pooled_projections)
1117
+ del pooled_projections # free memory
1118
+
1119
+ hidden_states = self.proj_in(hidden_states)
1120
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
1121
+ del temb, attention_mask # free memory
1122
+
1123
+ return hidden_states
1124
+
1125
+
1126
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
1127
+ def __init__(self, rope_dim, theta):
1128
+ super().__init__()
1129
+ self.DT, self.DY, self.DX = rope_dim
1130
+ self.theta = theta
1131
+ self.h_w_scaling_factor = 1.0
1132
+
1133
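+ # Builds per-axis rotary frequency tables: the outer product of inverse-theta frequencies with the
+ # position grid, with each frequency repeated twice for the paired rotary channels; returns cos and sin.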
+ @torch.no_grad()
1134
+ def get_frequency(self, dim, pos):
1135
+ T, H, W = pos.shape
1136
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
1137
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
1138
+ return freqs.cos(), freqs.sin()
1139
+
1140
+ @torch.no_grad()
1141
+ def forward_inner(self, frame_indices, height, width, device):
1142
+ GT, GY, GX = torch.meshgrid(
1143
+ frame_indices.to(device=device, dtype=torch.float32),
1144
+ torch.arange(0, height, device=device, dtype=torch.float32) * self.h_w_scaling_factor,
1145
+ torch.arange(0, width, device=device, dtype=torch.float32) * self.h_w_scaling_factor,
1146
+ indexing="ij",
1147
+ )
1148
+
1149
+ FCT, FST = self.get_frequency(self.DT, GT)
1150
+ del GT # free memory
1151
+ FCY, FSY = self.get_frequency(self.DY, GY)
1152
+ del GY # free memory
1153
+ FCX, FSX = self.get_frequency(self.DX, GX)
1154
+ del GX # free memory
1155
+
1156
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
1157
+ del FCT, FCY, FCX, FST, FSY, FSX # free memory
1158
+
1159
+ # Return result already on the correct device
1160
+ return result # Shape (2 * total_dim / 2, T, H, W) -> (total_dim, T, H, W)
1161
+
1162
+ @torch.no_grad()
1163
+ def forward(self, frame_indices, height, width, device):
1164
+ frame_indices = frame_indices.unbind(0)
1165
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
1166
+ results = torch.stack(results, dim=0)
1167
+ return results
1168
+
1169
+
1170
+ class AdaLayerNormZero(nn.Module):
1171
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
1172
+ super().__init__()
1173
+ self.silu = nn.SiLU()
1174
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
1175
+ if norm_type == "layer_norm":
1176
+ self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
1177
+ else:
1178
+ raise ValueError(f"unknown norm_type {norm_type}")
1179
+
1180
+ def forward(
1181
+ self, x: torch.Tensor, emb: Optional[torch.Tensor] = None
1182
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1183
+ emb = emb.unsqueeze(-2)
1184
+ emb = self.linear(self.silu(emb))
1185
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
1186
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
1187
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
1188
+
1189
+
1190
+ class AdaLayerNormZeroSingle(nn.Module):
1191
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
1192
+ super().__init__()
1193
+
1194
+ self.silu = nn.SiLU()
1195
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
1196
+ if norm_type == "layer_norm":
1197
+ self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
1198
+ else:
1199
+ raise ValueError(f"unknown norm_type {norm_type}")
1200
+
1201
+ def forward(
1202
+ self,
1203
+ x: torch.Tensor,
1204
+ emb: Optional[torch.Tensor] = None,
1205
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1206
+ emb = emb.unsqueeze(-2)
1207
+ emb = self.linear(self.silu(emb))
1208
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
1209
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
1210
+ return x, gate_msa
1211
+
1212
+
1213
+ class AdaLayerNormContinuous(nn.Module):
1214
+ def __init__(
1215
+ self,
1216
+ embedding_dim: int,
1217
+ conditioning_embedding_dim: int,
1218
+ elementwise_affine=True,
1219
+ eps=1e-5,
1220
+ bias=True,
1221
+ norm_type="layer_norm",
1222
+ ):
1223
+ super().__init__()
1224
+ self.silu = nn.SiLU()
1225
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
1226
+ if norm_type == "layer_norm":
1227
+ self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
1228
+ else:
1229
+ raise ValueError(f"unknown norm_type {norm_type}")
1230
+
1231
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
1232
+ emb = emb.unsqueeze(-2)
1233
+ emb = self.linear(self.silu(emb))
1234
+ scale, shift = emb.chunk(2, dim=-1)
1235
+ del emb # free memory
1236
+ x = self.norm(x) * (1 + scale) + shift
1237
+ return x
1238
+
1239
+
1240
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
1241
+ def __init__(
1242
+ self,
1243
+ num_attention_heads: int,
1244
+ attention_head_dim: int,
1245
+ mlp_ratio: float = 4.0,
1246
+ qk_norm: str = "rms_norm",
1247
+ attn_mode: Optional[str] = None,
1248
+ split_attn: Optional[bool] = False,
1249
+ ) -> None:
1250
+ super().__init__()
1251
+
1252
+ hidden_size = num_attention_heads * attention_head_dim
1253
+ mlp_dim = int(hidden_size * mlp_ratio)
1254
+ self.attn_mode = attn_mode
1255
+ self.split_attn = split_attn
1256
+
1257
+ # Attention layer (pre_only=True means no output projection in Attention module itself)
1258
+ self.attn = Attention(
1259
+ query_dim=hidden_size,
1260
+ cross_attention_dim=None,
1261
+ dim_head=attention_head_dim,
1262
+ heads=num_attention_heads,
1263
+ out_dim=hidden_size,
1264
+ bias=True,
1265
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
1266
+ qk_norm=qk_norm,
1267
+ eps=1e-6,
1268
+ pre_only=True, # Crucial: Attn processor will return raw attention output
1269
+ )
1270
+
1271
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
1272
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
1273
+ self.act_mlp = nn.GELU(approximate="tanh")
1274
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
1275
+
1276
+ def forward(
1277
+ self,
1278
+ hidden_states: torch.Tensor,
1279
+ encoder_hidden_states: torch.Tensor,
1280
+ temb: torch.Tensor,
1281
+ attention_mask: Optional[torch.Tensor] = None,
1282
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1283
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1284
+ text_seq_length = encoder_hidden_states.shape[1]
1285
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
1286
+ del encoder_hidden_states # free memory
1287
+
1288
+ residual = hidden_states
1289
+
1290
+ # 1. Input normalization
1291
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
1292
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
1293
+
1294
+ norm_hidden_states, norm_encoder_hidden_states = (
1295
+ norm_hidden_states[:, :-text_seq_length, :],
1296
+ norm_hidden_states[:, -text_seq_length:, :],
1297
+ )
1298
+
1299
+ # 2. Attention
1300
+ attn_output, context_attn_output = self.attn(
1301
+ hidden_states=norm_hidden_states,
1302
+ encoder_hidden_states=norm_encoder_hidden_states,
1303
+ attention_mask=attention_mask,
1304
+ image_rotary_emb=image_rotary_emb,
1305
+ attn_mode=self.attn_mode,
1306
+ split_attn=self.split_attn,
1307
+ )
1308
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
1309
+ del norm_hidden_states, norm_encoder_hidden_states, context_attn_output # free memory
1310
+ del image_rotary_emb
1311
+
1312
+ # 3. Modulation and residual connection
1313
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
1314
+ del attn_output, mlp_hidden_states # free memory
1315
+ hidden_states = gate * self.proj_out(hidden_states)
1316
+ hidden_states = hidden_states + residual
1317
+
1318
+ hidden_states, encoder_hidden_states = (
1319
+ hidden_states[:, :-text_seq_length, :],
1320
+ hidden_states[:, -text_seq_length:, :],
1321
+ )
1322
+ return hidden_states, encoder_hidden_states
1323
+
1324
+
1325
+ class HunyuanVideoTransformerBlock(nn.Module):
1326
+ def __init__(
1327
+ self,
1328
+ num_attention_heads: int,
1329
+ attention_head_dim: int,
1330
+ mlp_ratio: float,
1331
+ qk_norm: str = "rms_norm",
1332
+ attn_mode: Optional[str] = None,
1333
+ split_attn: Optional[bool] = False,
1334
+ ) -> None:
1335
+ super().__init__()
1336
+
1337
+ hidden_size = num_attention_heads * attention_head_dim
1338
+ self.attn_mode = attn_mode
1339
+ self.split_attn = split_attn
1340
+
1341
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
1342
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
1343
+
1344
+ self.attn = Attention(
1345
+ query_dim=hidden_size,
1346
+ cross_attention_dim=None,
1347
+ added_kv_proj_dim=hidden_size,
1348
+ dim_head=attention_head_dim,
1349
+ heads=num_attention_heads,
1350
+ out_dim=hidden_size,
1351
+ context_pre_only=False,
1352
+ bias=True,
1353
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
1354
+ qk_norm=qk_norm,
1355
+ eps=1e-6,
1356
+ )
1357
+
1358
+ self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
1359
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
1360
+
1361
+ self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
1362
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
1363
+
1364
+ def forward(
1365
+ self,
1366
+ hidden_states: torch.Tensor,
1367
+ encoder_hidden_states: torch.Tensor,
1368
+ temb: torch.Tensor,
1369
+ attention_mask: Optional[torch.Tensor] = None,
1370
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1372
+ # 1. Input normalization
1373
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
1374
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
1375
+ encoder_hidden_states, emb=temb
1376
+ )
1377
+
1378
+ # 2. Joint attention
1379
+ attn_output, context_attn_output = self.attn(
1380
+ hidden_states=norm_hidden_states,
1381
+ encoder_hidden_states=norm_encoder_hidden_states,
1382
+ attention_mask=attention_mask,
1383
+ image_rotary_emb=freqs_cis,
1384
+ attn_mode=self.attn_mode,
1385
+ split_attn=self.split_attn,
1386
+ )
1387
+ del norm_hidden_states, norm_encoder_hidden_states, freqs_cis # free memory
1388
+
1389
+ # 3. Modulation and residual connection
1390
+ hidden_states = hidden_states + attn_output * gate_msa
1391
+ del attn_output, gate_msa # free memory
1392
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
1393
+ del context_attn_output, c_gate_msa # free memory
1394
+
1395
+ norm_hidden_states = self.norm2(hidden_states)
1396
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
1397
+
1398
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1399
+ del shift_mlp, scale_mlp # free memory
1400
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
1401
+ del c_shift_mlp, c_scale_mlp # free memory
1402
+
1403
+ # 4. Feed-forward
1404
+ ff_output = self.ff(norm_hidden_states)
1405
+ del norm_hidden_states # free memory
1406
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
1407
+ del norm_encoder_hidden_states # free memory
1408
+
1409
+ hidden_states = hidden_states + gate_mlp * ff_output
1410
+ del ff_output, gate_mlp # free memory
1411
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
1412
+ del context_ff_output, c_gate_mlp # free memory
1413
+
1414
+ return hidden_states, encoder_hidden_states
1415
+
1416
+
1417
+ class ClipVisionProjection(nn.Module):
1418
+ def __init__(self, in_channels, out_channels):
1419
+ super().__init__()
1420
+ self.up = nn.Linear(in_channels, out_channels * 3)
1421
+ self.down = nn.Linear(out_channels * 3, out_channels)
1422
+
1423
+ def forward(self, x):
1424
+ projected_x = self.down(nn.functional.silu(self.up(x)))
1425
+ return projected_x
1426
+
1427
+
1428
+ class HunyuanVideoPatchEmbed(nn.Module):
1429
+ def __init__(self, patch_size, in_chans, embed_dim):
1430
+ super().__init__()
1431
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
1432
+
1433
+
1434
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
1435
+ def __init__(self, inner_dim):
1436
+ super().__init__()
1437
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
1438
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
1439
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
1440
+
1441
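+ # Initializes the 2x/4x clean-latent patch embedders by tiling the base 1x conv weights over the larger
+ # kernel and dividing by the number of copies (8 and 64), so they initially behave like the base
+ # projection applied to an averaged patch.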
+ @torch.no_grad()
1442
+ def initialize_weight_from_another_conv3d(self, another_layer):
1443
+ weight = another_layer.weight.detach().clone()
1444
+ bias = another_layer.bias.detach().clone()
1445
+
1446
+ sd = {
1447
+ "proj.weight": weight.clone(),
1448
+ "proj.bias": bias.clone(),
1449
+ "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0,
1450
+ "proj_2x.bias": bias.clone(),
1451
+ "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0,
1452
+ "proj_4x.bias": bias.clone(),
1453
+ }
1454
+
1455
+ sd = {k: v.clone() for k, v in sd.items()}
1456
+
1457
+ self.load_state_dict(sd)
1458
+ return
1459
+
1460
+
1461
+ class HunyuanVideoTransformer3DModelPacked(nn.Module): # (PreTrainedModelMixin, GenerationMixin,
1462
+ # ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
1463
+ # @register_to_config
1464
+ def __init__(
1465
+ self,
1466
+ in_channels: int = 16,
1467
+ out_channels: int = 16,
1468
+ num_attention_heads: int = 24,
1469
+ attention_head_dim: int = 128,
1470
+ num_layers: int = 20,
1471
+ num_single_layers: int = 40,
1472
+ num_refiner_layers: int = 2,
1473
+ mlp_ratio: float = 4.0,
1474
+ patch_size: int = 2,
1475
+ patch_size_t: int = 1,
1476
+ qk_norm: str = "rms_norm",
1477
+ guidance_embeds: bool = True,
1478
+ text_embed_dim: int = 4096,
1479
+ pooled_projection_dim: int = 768,
1480
+ rope_theta: float = 256.0,
1481
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
1482
+ has_image_proj=False,
1483
+ image_proj_dim=1152,
1484
+ has_clean_x_embedder=False,
1485
+ attn_mode: Optional[str] = None,
1486
+ split_attn: Optional[bool] = False,
1487
+ ) -> None:
1488
+ super().__init__()
1489
+
1490
+ inner_dim = num_attention_heads * attention_head_dim
1491
+ out_channels = out_channels or in_channels
1492
+ self.config_patch_size = patch_size
1493
+ self.config_patch_size_t = patch_size_t
1494
+
1495
+ # 1. Latent and condition embedders
1496
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
1497
+ self.context_embedder = HunyuanVideoTokenRefiner(
1498
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
1499
+ )
1500
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
1501
+
1502
+ self.clean_x_embedder = None
1503
+ self.image_projection = None
1504
+
1505
+ # 2. RoPE
1506
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
1507
+
1508
+ # 3. Dual stream transformer blocks
1509
+ self.transformer_blocks = nn.ModuleList(
1510
+ [
1511
+ HunyuanVideoTransformerBlock(
1512
+ num_attention_heads,
1513
+ attention_head_dim,
1514
+ mlp_ratio=mlp_ratio,
1515
+ qk_norm=qk_norm,
1516
+ attn_mode=attn_mode,
1517
+ split_attn=split_attn,
1518
+ )
1519
+ for _ in range(num_layers)
1520
+ ]
1521
+ )
1522
+
1523
+ # 4. Single stream transformer blocks
1524
+ self.single_transformer_blocks = nn.ModuleList(
1525
+ [
1526
+ HunyuanVideoSingleTransformerBlock(
1527
+ num_attention_heads,
1528
+ attention_head_dim,
1529
+ mlp_ratio=mlp_ratio,
1530
+ qk_norm=qk_norm,
1531
+ attn_mode=attn_mode,
1532
+ split_attn=split_attn,
1533
+ )
1534
+ for _ in range(num_single_layers)
1535
+ ]
1536
+ )
1537
+
1538
+ # 5. Output projection
1539
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
1540
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
1541
+
1542
+ self.inner_dim = inner_dim
1543
+ self.use_gradient_checkpointing = False
1544
+ self.enable_teacache = False
1545
+
1546
+ # if has_image_proj:
1547
+ # self.install_image_projection(image_proj_dim)
1548
+ self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim)
1549
+ # self.config["has_image_proj"] = True
1550
+ # self.config["image_proj_dim"] = in_channels
1551
+
1552
+ # if has_clean_x_embedder:
1553
+ # self.install_clean_x_embedder()
1554
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
1555
+ # self.config["has_clean_x_embedder"] = True
1556
+
1557
+ self.high_quality_fp32_output_for_inference = True  # default changed from False to True
1558
+
1559
+ # Block swapping attributes (initialized to None)
1560
+ self.blocks_to_swap = None
1561
+ self.offloader_double = None
1562
+ self.offloader_single = None
1563
+
1564
+ # RoPE scaling
1565
+ self.rope_scaling_timestep_threshold: Optional[int] = None # scale RoPE above this timestep
1566
+ self.rope_scaling_factor: float = 1.0 # RoPE scaling factor
1567
+
1568
+ @property
1569
+ def device(self):
1570
+ return next(self.parameters()).device
1571
+
1572
+ @property
1573
+ def dtype(self):
1574
+ return next(self.parameters()).dtype
1575
+
1576
+ def enable_gradient_checkpointing(self):
1577
+ self.use_gradient_checkpointing = True
1578
+ print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.") # Logging
1579
+
1580
+ def disable_gradient_checkpointing(self):
1581
+ self.use_gradient_checkpointing = False
1582
+ print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.") # Logging
1583
+
1584
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
1585
+ self.enable_teacache = enable_teacache
1586
+ self.cnt = 0
1587
+ self.num_steps = num_steps
1588
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
1589
+ self.accumulated_rel_l1_distance = 0
1590
+ self.previous_modulated_input = None
1591
+ self.previous_residual = None
1592
+ self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02])
1593
+ if enable_teacache:
1594
+ print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}")
1595
+ else:
1596
+ print("TeaCache disabled.")
1597
+
1598
+ def gradient_checkpointing_method(self, block, *args):
1599
+ if self.use_gradient_checkpointing:
1600
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
1601
+ else:
1602
+ result = block(*args)
1603
+ return result
1604
+
1605
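+ # Splits the requested number of swapped blocks between the double-stream and single-stream stacks;
+ # presumably because single blocks are about half the size, roughly twice as many of them are swapped.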
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
1606
+ self.blocks_to_swap = num_blocks
1607
+ self.num_double_blocks = len(self.transformer_blocks)
1608
+ self.num_single_blocks = len(self.single_transformer_blocks)
1609
+ double_blocks_to_swap = num_blocks // 2
1610
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
1611
+
1612
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
1613
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
1614
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
1615
+ )
1616
+
1617
+ self.offloader_double = ModelOffloader(
1618
+ "double",
1619
+ self.transformer_blocks,
1620
+ self.num_double_blocks,
1621
+ double_blocks_to_swap,
1622
+ supports_backward,
1623
+ device,
1624
+ # debug=True # Optional debugging
1625
+ )
1626
+ self.offloader_single = ModelOffloader(
1627
+ "single",
1628
+ self.single_transformer_blocks,
1629
+ self.num_single_blocks,
1630
+ single_blocks_to_swap,
1631
+ supports_backward,
1632
+ device, # , debug=True
1633
+ )
1634
+ print(
1635
+ f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, "
1636
+ + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}."
1637
+ )
1638
+
1639
+ def switch_block_swap_for_inference(self):
1640
+ if self.blocks_to_swap and self.blocks_to_swap > 0:
1641
+ self.offloader_double.set_forward_only(True)
1642
+ self.offloader_single.set_forward_only(True)
1643
+ self.prepare_block_swap_before_forward()
1644
+ print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.")
1645
+
1646
+ def switch_block_swap_for_training(self):
1647
+ if self.blocks_to_swap and self.blocks_to_swap > 0:
1648
+ self.offloader_double.set_forward_only(False)
1649
+ self.offloader_single.set_forward_only(False)
1650
+ self.prepare_block_swap_before_forward()
1651
+ print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.")
1652
+
1653
+ def move_to_device_except_swap_blocks(self, device: torch.device):
1654
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
1655
+ if self.blocks_to_swap:
1656
+ saved_double_blocks = self.transformer_blocks
1657
+ saved_single_blocks = self.single_transformer_blocks
1658
+ self.transformer_blocks = None
1659
+ self.single_transformer_blocks = None
1660
+
1661
+ self.to(device)
1662
+
1663
+ if self.blocks_to_swap:
1664
+ self.transformer_blocks = saved_double_blocks
1665
+ self.single_transformer_blocks = saved_single_blocks
1666
+
1667
+ def prepare_block_swap_before_forward(self):
1668
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
1669
+ return
1670
+ self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks)
1671
+ self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks)
1672
+
1673
+ def enable_rope_scaling(self, timestep_threshold: Optional[int], rope_scaling_factor: float = 1.0):
1674
+ if timestep_threshold is not None and rope_scaling_factor > 0:
1675
+ self.rope_scaling_timestep_threshold = timestep_threshold
1676
+ self.rope_scaling_factor = rope_scaling_factor
1677
+ logger.info(f"RoPE scaling enabled: threshold={timestep_threshold}, scaling_factor={rope_scaling_factor}.")
1678
+ else:
1679
+ self.rope_scaling_timestep_threshold = None
1680
+ self.rope_scaling_factor = 1.0
1681
+ self.rope.h_w_scaling_factor = 1.0 # reset to default
1682
+ logger.info("RoPE scaling disabled.")
1683
+
1684
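+ # Patch-embeds the noisy latents and, when provided, the clean-latent history at full, 2x and 4x
+ # downsampled resolutions, prepending the context tokens (and their RoPE frequencies) to the sequence.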
+ def process_input_hidden_states(
1685
+ self,
1686
+ latents,
1687
+ latent_indices=None,
1688
+ clean_latents=None,
1689
+ clean_latent_indices=None,
1690
+ clean_latents_2x=None,
1691
+ clean_latent_2x_indices=None,
1692
+ clean_latents_4x=None,
1693
+ clean_latent_4x_indices=None,
1694
+ ):
1695
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
1696
+ B, C, T, H, W = hidden_states.shape
1697
+
1698
+ if latent_indices is None:
1699
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
1700
+
1701
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
1702
+
1703
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
1704
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
1705
+
1706
+ if clean_latents is not None and clean_latent_indices is not None:
1707
+ clean_latents = clean_latents.to(hidden_states)
1708
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
1709
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
1710
+
1711
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
1712
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
1713
+
1714
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
1715
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
1716
+
1717
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
1718
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
1719
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
1720
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
1721
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
1722
+
1723
+ clean_latent_2x_rope_freqs = self.rope(
1724
+ frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device
1725
+ )
1726
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
1727
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
1728
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
1729
+
1730
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
1731
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
1732
+
1733
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
1734
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
1735
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
1736
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
1737
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
1738
+
1739
+ clean_latent_4x_rope_freqs = self.rope(
1740
+ frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device
1741
+ )
1742
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
1743
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
1744
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
1745
+
1746
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
1747
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
1748
+
1749
+ return hidden_states, rope_freqs
1750
+
1751
+ def forward(
1752
+ self,
1753
+ hidden_states,
1754
+ timestep,
1755
+ encoder_hidden_states,
1756
+ encoder_attention_mask,
1757
+ pooled_projections,
1758
+ guidance,
1759
+ latent_indices=None,
1760
+ clean_latents=None,
1761
+ clean_latent_indices=None,
1762
+ clean_latents_2x=None,
1763
+ clean_latent_2x_indices=None,
1764
+ clean_latents_4x=None,
1765
+ clean_latent_4x_indices=None,
1766
+ image_embeddings=None,
1767
+ attention_kwargs=None,
1768
+ return_dict=True,
1769
+ ):
1770
+
1771
+ if attention_kwargs is None:
1772
+ attention_kwargs = {}
1773
+
1774
+ # RoPE scaling: must be done before processing hidden states
1775
+ if self.rope_scaling_timestep_threshold is not None:
1776
+ if timestep >= self.rope_scaling_timestep_threshold:
1777
+ self.rope.h_w_scaling_factor = self.rope_scaling_factor
1778
+ else:
1779
+ self.rope.h_w_scaling_factor = 1.0
1780
+
1781
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
1782
+ p, p_t = self.config_patch_size, self.config_patch_size_t
1783
+ post_patch_num_frames = num_frames // p_t
1784
+ post_patch_height = height // p
1785
+ post_patch_width = width // p
1786
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
1787
+
1788
+ hidden_states, rope_freqs = self.process_input_hidden_states(
1789
+ hidden_states,
1790
+ latent_indices,
1791
+ clean_latents,
1792
+ clean_latent_indices,
1793
+ clean_latents_2x,
1794
+ clean_latent_2x_indices,
1795
+ clean_latents_4x,
1796
+ clean_latent_4x_indices,
1797
+ )
1798
+ del (
1799
+ latent_indices,
1800
+ clean_latents,
1801
+ clean_latent_indices,
1802
+ clean_latents_2x,
1803
+ clean_latent_2x_indices,
1804
+ clean_latents_4x,
1805
+ clean_latent_4x_indices,
1806
+ ) # free memory
1807
+
1808
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
1809
+ encoder_hidden_states = self.gradient_checkpointing_method(
1810
+ self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask
1811
+ )
1812
+
1813
+ if self.image_projection is not None:
1814
+ assert image_embeddings is not None, "You must use image embeddings!"
1815
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
1816
+ extra_attention_mask = torch.ones(
1817
+ (batch_size, extra_encoder_hidden_states.shape[1]),
1818
+ dtype=encoder_attention_mask.dtype,
1819
+ device=encoder_attention_mask.device,
1820
+ )
1821
+
1822
+ # must cat before (not after) encoder_hidden_states, due to attn masking
1823
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
1824
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
1825
+ del extra_encoder_hidden_states, extra_attention_mask # free memory
1826
+
1827
+ with torch.no_grad():
1828
+ if batch_size == 1:
1829
+ # When batch size is 1, we do not need any masks or var-len functions, since cropping to the valid text length is mathematically equivalent to what we want.
1830
+ # If a masked implementation gives a different result, that implementation is wrong; this one is the correct reference.
1831
+ text_len = encoder_attention_mask.sum().item()
1832
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
1833
+ attention_mask = None, None, None, None
1834
+ else:
1835
+ img_seq_len = hidden_states.shape[1]
1836
+ txt_seq_len = encoder_hidden_states.shape[1]
1837
+
1838
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
1839
+ cu_seqlens_kv = cu_seqlens_q
1840
+ max_seqlen_q = img_seq_len + txt_seq_len
1841
+ max_seqlen_kv = max_seqlen_q
1842
+
1843
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
1844
+ del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv # free memory
1845
+ del encoder_attention_mask # free memory
1846
+
1847
+ if self.enable_teacache:
1848
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
1849
+
1850
+ if self.cnt == 0 or self.cnt == self.num_steps - 1:
1851
+ should_calc = True
1852
+ self.accumulated_rel_l1_distance = 0
1853
+ else:
1854
+ curr_rel_l1 = (
1855
+ ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())
1856
+ .cpu()
1857
+ .item()
1858
+ )
1859
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
1860
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
1861
+
1862
+ if should_calc:
1863
+ self.accumulated_rel_l1_distance = 0
1864
+
1865
+ self.previous_modulated_input = modulated_inp
1866
+ self.cnt += 1
1867
+
1868
+ if self.cnt == self.num_steps:
1869
+ self.cnt = 0
1870
+
1871
+ if not should_calc:
1872
+ hidden_states = hidden_states + self.previous_residual
1873
+ else:
1874
+ ori_hidden_states = hidden_states.clone()
1875
+
1876
+ for block_id, block in enumerate(self.transformer_blocks):
1877
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1878
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1879
+ )
1880
+
1881
+ for block_id, block in enumerate(self.single_transformer_blocks):
1882
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1883
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1884
+ )
1885
+
1886
+ self.previous_residual = hidden_states - ori_hidden_states
1887
+ del ori_hidden_states # free memory
1888
+ else:
1889
+ for block_id, block in enumerate(self.transformer_blocks):
1890
+ if self.blocks_to_swap:
1891
+ self.offloader_double.wait_for_block(block_id)
1892
+
1893
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1894
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1895
+ )
1896
+
1897
+ if self.blocks_to_swap:
1898
+ self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)
1899
+
1900
+ for block_id, block in enumerate(self.single_transformer_blocks):
1901
+ if self.blocks_to_swap:
1902
+ self.offloader_single.wait_for_block(block_id)
1903
+
1904
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1905
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1906
+ )
1907
+
1908
+ if self.blocks_to_swap:
1909
+ self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)
1910
+
1911
+ del attention_mask, rope_freqs # free memory
1912
+ del encoder_hidden_states # free memory
1913
+
1914
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1915
+
1916
+ hidden_states = hidden_states[:, -original_context_length:, :]
1917
+
1918
+ if self.high_quality_fp32_output_for_inference:
1919
+ hidden_states = hidden_states.to(dtype=torch.float32)
1920
+ if self.proj_out.weight.dtype != torch.float32:
1921
+ self.proj_out.to(dtype=torch.float32)
1922
+
1923
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1924
+
1925
+ hidden_states = einops.rearrange(
1926
+ hidden_states,
1927
+ "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)",
1928
+ t=post_patch_num_frames,
1929
+ h=post_patch_height,
1930
+ w=post_patch_width,
1931
+ pt=p_t,
1932
+ ph=p,
1933
+ pw=p,
1934
+ )
1935
+
1936
+ if return_dict:
1937
+ # return Transformer2DModelOutput(sample=hidden_states)
1938
+ return SimpleNamespace(sample=hidden_states)
1939
+
1940
+ return (hidden_states,)
1941
+
1942
+ def fp8_optimization(
1943
+ self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
1944
+ ) -> dict[str, torch.Tensor]: # Return type hint added
1945
+ """
1946
+ Optimize the model state_dict with fp8.
1947
+
1948
+ Args:
1949
+ state_dict (dict[str, torch.Tensor]):
1950
+ The state_dict of the model.
1951
+ device (torch.device):
1952
+ The device to calculate the weight.
1953
+ move_to_device (bool):
1954
+ Whether to move the weight to the device after optimization.
1955
+ use_scaled_mm (bool):
1956
+ Whether to use scaled matrix multiplication for FP8.
1957
+ """
1958
+ TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
1959
+ EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
1960
+
1961
+ # inplace optimization
1962
+ state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)
1963
+
1964
+ # apply monkey patching
1965
+ apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)
1966
+
1967
+ return state_dict
1968
+
1969
+
1970
+ def load_packed_model(
1971
+ device: Union[str, torch.device],
1972
+ dit_path: str,
1973
+ attn_mode: str,
1974
+ loading_device: Union[str, torch.device],
1975
+ fp8_scaled: bool = False,
1976
+ split_attn: bool = False,
1977
+ ) -> HunyuanVideoTransformer3DModelPacked:
1978
+ # TODO support split_attn
1979
+ device = torch.device(device)
1980
+ loading_device = torch.device(loading_device)
1981
+
1982
+ if os.path.isdir(dit_path):
1983
+ # we don't support from_pretrained for now, so loading safetensors directly
1984
+ safetensor_files = glob.glob(os.path.join(dit_path, "*.safetensors"))
1985
+ if len(safetensor_files) == 0:
1986
+ raise ValueError(f"Cannot find safetensors file in {dit_path}")
1987
+ # sort by name and take the first one
1988
+ safetensor_files.sort()
1989
+ dit_path = safetensor_files[0]
1990
+
1991
+ with init_empty_weights():
1992
+ logger.info(f"Creating HunyuanVideoTransformer3DModelPacked")
1993
+ model = HunyuanVideoTransformer3DModelPacked(
1994
+ attention_head_dim=128,
1995
+ guidance_embeds=True,
1996
+ has_clean_x_embedder=True,
1997
+ has_image_proj=True,
1998
+ image_proj_dim=1152,
1999
+ in_channels=16,
2000
+ mlp_ratio=4.0,
2001
+ num_attention_heads=24,
2002
+ num_layers=20,
2003
+ num_refiner_layers=2,
2004
+ num_single_layers=40,
2005
+ out_channels=16,
2006
+ patch_size=2,
2007
+ patch_size_t=1,
2008
+ pooled_projection_dim=768,
2009
+ qk_norm="rms_norm",
2010
+ rope_axes_dim=(16, 56, 56),
2011
+ rope_theta=256.0,
2012
+ text_embed_dim=4096,
2013
+ attn_mode=attn_mode,
2014
+ split_attn=split_attn,
2015
+ )
2016
+
2017
+ # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
2018
+ dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device
2019
+ logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}")
2020
+
2021
+ # load model weights with the specified dtype or as is
2022
+ sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True)
2023
+
2024
+ if fp8_scaled:
2025
+ # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
2026
+ logger.info(f"Optimizing model weights to fp8. This may take a while.")
2027
+ sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")
2028
+
2029
+ if loading_device.type != "cpu":
2030
+ # make sure all the model weights are on the loading_device
2031
+ logger.info(f"Moving weights to {loading_device}")
2032
+ for key in sd.keys():
2033
+ sd[key] = sd[key].to(loading_device)
2034
+
2035
+ info = model.load_state_dict(sd, strict=True, assign=True)
2036
+ logger.info(f"Loaded DiT model from {dit_path}, info={info}")
2037
+
2038
+ return model
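For reference, a minimal sketch of how load_packed_model above might be called; the checkpoint path, attention mode, and devices below are placeholders rather than values taken from this commit:

    # Hypothetical loading sketch; dit_path and attn_mode values are assumptions.
    import torch
    from frame_pack.hunyuan_video_packed import load_packed_model

    model = load_packed_model(
        device=torch.device("cuda"),                     # device used for fp8 scaling math
        dit_path="/path/to/framepack_dit.safetensors",   # placeholder checkpoint path
        attn_mode="torch",                               # assumed scaled-dot-product attention mode
        loading_device="cpu",                            # keep weights on CPU, e.g. for block swap
        fp8_scaled=False,
    )
    model.eval()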
frame_pack/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,128 @@
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import torch
5
+ import math
6
+
7
+ # from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
8
+ # from diffusers_helper.k_diffusion.wrapper import fm_wrapper
9
+ # from diffusers_helper.utils import repeat_to_batch_size
10
+ from frame_pack.uni_pc_fm import sample_unipc
11
+ from frame_pack.wrapper import fm_wrapper
12
+ from frame_pack.utils import repeat_to_batch_size
13
+
14
+
15
+ def flux_time_shift(t, mu=1.15, sigma=1.0):
16
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
17
+
18
+
19
+ def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
20
+ k = (y2 - y1) / (x2 - x1)
21
+ b = y1 - k * x1
22
+ mu = k * context_length + b
23
+ mu = min(mu, math.log(exp_max))
24
+ return mu
25
+
26
+
27
+ def get_flux_sigmas_from_mu(n, mu):
28
+ sigmas = torch.linspace(1, 0, steps=n + 1)
29
+ sigmas = flux_time_shift(sigmas, mu=mu)
30
+ return sigmas
31
+
32
+
33
+ # @torch.inference_mode()
34
+ def sample_hunyuan(
35
+ transformer,
36
+ sampler="unipc",
37
+ initial_latent=None,
38
+ concat_latent=None,
39
+ strength=1.0,
40
+ width=512,
41
+ height=512,
42
+ frames=16,
43
+ real_guidance_scale=1.0,
44
+ distilled_guidance_scale=6.0,
45
+ guidance_rescale=0.0,
46
+ shift=None,
47
+ num_inference_steps=25,
48
+ batch_size=None,
49
+ generator=None,
50
+ prompt_embeds=None,
51
+ prompt_embeds_mask=None,
52
+ prompt_poolers=None,
53
+ negative_prompt_embeds=None,
54
+ negative_prompt_embeds_mask=None,
55
+ negative_prompt_poolers=None,
56
+ dtype=torch.bfloat16,
57
+ device=None,
58
+ negative_kwargs=None,
59
+ callback=None,
60
+ **kwargs,
61
+ ):
62
+ device = device or transformer.device
63
+
64
+ if batch_size is None:
65
+ batch_size = int(prompt_embeds.shape[0])
66
+
67
+ latents = torch.randn(
68
+ (batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device
69
+ ).to(device=device, dtype=torch.float32)
70
+
71
+ B, C, T, H, W = latents.shape
72
+ seq_length = T * H * W // 4 # 9*80*80//4 = 14400
73
+
74
+ if shift is None:
75
+ mu = calculate_flux_mu(seq_length, exp_max=7.0) # 1.9459... if seq_len is large, mu is clipped.
76
+ else:
77
+ mu = math.log(shift)
78
+
79
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
80
+
81
+ k_model = fm_wrapper(transformer)
82
+
83
+ if initial_latent is not None:
84
+ sigmas = sigmas * strength
85
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
86
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
87
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
88
+
89
+ if concat_latent is not None:
90
+ concat_latent = concat_latent.to(latents)
91
+
92
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
93
+
94
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
95
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
96
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
97
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
98
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
99
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
100
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
101
+
102
+ sampler_kwargs = dict(
103
+ dtype=dtype,
104
+ cfg_scale=real_guidance_scale,
105
+ cfg_rescale=guidance_rescale,
106
+ concat_latent=concat_latent,
107
+ positive=dict(
108
+ pooled_projections=prompt_poolers,
109
+ encoder_hidden_states=prompt_embeds,
110
+ encoder_attention_mask=prompt_embeds_mask,
111
+ guidance=distilled_guidance,
112
+ **kwargs,
113
+ ),
114
+ negative=dict(
115
+ pooled_projections=negative_prompt_poolers,
116
+ encoder_hidden_states=negative_prompt_embeds,
117
+ encoder_attention_mask=negative_prompt_embeds_mask,
118
+ guidance=distilled_guidance,
119
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
120
+ ),
121
+ )
122
+
123
+ if sampler == "unipc":
124
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
125
+ else:
126
+ raise NotImplementedError(f"Sampler {sampler} is not supported.")
127
+
128
+ return results
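The schedule helpers near the top of this file fully determine the sigma schedule used by sample_hunyuan; a standalone check of the mu clipping and the endpoints (values are illustrative, and the import path follows this commit's frame_pack package):

    # Illustrative check of the flux time-shift schedule defined above.
    from frame_pack.k_diffusion_hunyuan import calculate_flux_mu, get_flux_sigmas_from_mu

    seq_length = 9 * 80 * 80 // 4                      # 14400, matching the comment in sample_hunyuan
    mu = calculate_flux_mu(seq_length, exp_max=7.0)    # clipped to log(7) ~= 1.9459 for long sequences
    sigmas = get_flux_sigmas_from_mu(25, mu)           # 26 sigmas from 1.0 down to 0.0
    print(round(mu, 4), sigmas[0].item(), sigmas[-1].item())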
frame_pack/uni_pc_fm.py ADDED
@@ -0,0 +1,142 @@
1
+ # Better Flow Matching UniPC by Lvmin Zhang
2
+ # (c) 2025
3
+ # CC BY-SA 4.0
4
+ # Attribution-ShareAlike 4.0 International Licence
5
+
6
+
7
+ import torch
8
+
9
+ from tqdm.auto import trange
10
+
11
+
12
+ def expand_dims(v, dims):
13
+ return v[(...,) + (None,) * (dims - 1)]
14
+
15
+
16
+ class FlowMatchUniPC:
17
+ def __init__(self, model, extra_args, variant='bh1'):
18
+ self.model = model
19
+ self.variant = variant
20
+ self.extra_args = extra_args
21
+
22
+ def model_fn(self, x, t):
23
+ return self.model(x, t, **self.extra_args)
24
+
25
+ def update_fn(self, x, model_prev_list, t_prev_list, t, order):
26
+ assert order <= len(model_prev_list)
27
+ dims = x.dim()
28
+
29
+ t_prev_0 = t_prev_list[-1]
30
+ lambda_prev_0 = - torch.log(t_prev_0)
31
+ lambda_t = - torch.log(t)
32
+ model_prev_0 = model_prev_list[-1]
33
+
34
+ h = lambda_t - lambda_prev_0
35
+
36
+ rks = []
37
+ D1s = []
38
+ for i in range(1, order):
39
+ t_prev_i = t_prev_list[-(i + 1)]
40
+ model_prev_i = model_prev_list[-(i + 1)]
41
+ lambda_prev_i = - torch.log(t_prev_i)
42
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
43
+ rks.append(rk)
44
+ D1s.append((model_prev_i - model_prev_0) / rk)
45
+
46
+ rks.append(1.)
47
+ rks = torch.tensor(rks, device=x.device)
48
+
49
+ R = []
50
+ b = []
51
+
52
+ hh = -h[0]
53
+ h_phi_1 = torch.expm1(hh)
54
+ h_phi_k = h_phi_1 / hh - 1
55
+
56
+ factorial_i = 1
57
+
58
+ if self.variant == 'bh1':
59
+ B_h = hh
60
+ elif self.variant == 'bh2':
61
+ B_h = torch.expm1(hh)
62
+ else:
63
+ raise NotImplementedError('Bad variant!')
64
+
65
+ for i in range(1, order + 1):
66
+ R.append(torch.pow(rks, i - 1))
67
+ b.append(h_phi_k * factorial_i / B_h)
68
+ factorial_i *= (i + 1)
69
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
70
+
71
+ R = torch.stack(R)
72
+ b = torch.tensor(b, device=x.device)
73
+
74
+ use_predictor = len(D1s) > 0
75
+
76
+ if use_predictor:
77
+ D1s = torch.stack(D1s, dim=1)
78
+ if order == 2:
79
+ rhos_p = torch.tensor([0.5], device=b.device)
80
+ else:
81
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
82
+ else:
83
+ D1s = None
84
+ rhos_p = None
85
+
86
+ if order == 1:
87
+ rhos_c = torch.tensor([0.5], device=b.device)
88
+ else:
89
+ rhos_c = torch.linalg.solve(R, b)
90
+
91
+ x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
92
+
93
+ if use_predictor:
94
+ pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
95
+ else:
96
+ pred_res = 0
97
+
98
+ x_t = x_t_ - expand_dims(B_h, dims) * pred_res
99
+ model_t = self.model_fn(x_t, t)
100
+
101
+ if D1s is not None:
102
+ corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
103
+ else:
104
+ corr_res = 0
105
+
106
+ D1_t = (model_t - model_prev_0)
107
+ x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
108
+
109
+ return x_t, model_t
110
+
111
+ def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
+ order = min(3, len(sigmas) - 2)
113
+ model_prev_list, t_prev_list = [], []
114
+ for i in trange(len(sigmas) - 1, disable=disable_pbar):
115
+ vec_t = sigmas[i].expand(x.shape[0])
116
+
117
+ with torch.no_grad():
118
+ if i == 0:
119
+ model_prev_list = [self.model_fn(x, vec_t)]
120
+ t_prev_list = [vec_t]
121
+ elif i < order:
122
+ init_order = i
123
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
124
+ model_prev_list.append(model_x)
125
+ t_prev_list.append(vec_t)
126
+ else:
127
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
128
+ model_prev_list.append(model_x)
129
+ t_prev_list.append(vec_t)
130
+
131
+ model_prev_list = model_prev_list[-order:]
132
+ t_prev_list = t_prev_list[-order:]
133
+
134
+ if callback is not None:
135
+ callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
136
+
137
+ return model_prev_list[-1]
138
+
139
+
140
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
141
+ assert variant in ['bh1', 'bh2']
142
+ return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
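sample_unipc only assumes that the wrapped model maps (x, sigma, **extra_args) to an x0 prediction; a toy stand-in makes that contract concrete (shapes and step count below are arbitrary):

    # Toy check of the sampler interface; the zero-returning model is a stand-in, not real usage.
    import torch
    from frame_pack.uni_pc_fm import sample_unipc

    def toy_model(x, sigma, **extra_args):
        return torch.zeros_like(x)   # pretend the clean sample is all zeros

    noise = torch.randn(1, 4, 8, 8)
    sigmas = torch.linspace(1.0, 0.0, steps=11)
    denoised = sample_unipc(toy_model, noise, sigmas, extra_args={}, disable=True)
    print(denoised.shape)            # torch.Size([1, 4, 8, 8])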
frame_pack/utils.py ADDED
@@ -0,0 +1,617 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top : top + target_height, left : left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start : y_start + new_height, x_start : x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, "wt", encoding="utf-8") as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, "rt", encoding="utf-8") as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = [".lora_B.", "__zero__"]
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024**2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, "_forward_inside_frozen_module"):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, "*.safetensors"))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError("No file to resume!")
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(", ")
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ", ".join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, "(m n) c t h w -> t (m h) (n w) c", n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec="libx264", options={"crf": "0"})
280
+
281
+ # write tensor as .pt file
282
+ torch.save(x, output_filename.replace(".mp4", ".pt"))
283
+
284
+ return x
285
+
286
+
287
+ def save_bcthw_as_png(x, output_filename):
288
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
289
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
290
+ x = x.detach().cpu().to(torch.uint8)
291
+ x = einops.rearrange(x, "b c t h w -> c (b h) (t w)")
292
+ torchvision.io.write_png(x, output_filename)
293
+ return output_filename
294
+
295
+
296
+ def save_bchw_as_png(x, output_filename):
297
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
298
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
299
+ x = x.detach().cpu().to(torch.uint8)
300
+ x = einops.rearrange(x, "b c h w -> c h (b w)")
301
+ torchvision.io.write_png(x, output_filename)
302
+ return output_filename
303
+
304
+
305
+ def add_tensors_with_padding(tensor1, tensor2):
306
+ if tensor1.shape == tensor2.shape:
307
+ return tensor1 + tensor2
308
+
309
+ shape1 = tensor1.shape
310
+ shape2 = tensor2.shape
311
+
312
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
313
+
314
+ padded_tensor1 = torch.zeros(new_shape)
315
+ padded_tensor2 = torch.zeros(new_shape)
316
+
317
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
318
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
319
+
320
+ result = padded_tensor1 + padded_tensor2
321
+ return result
322
+
323
+
324
+ def print_free_mem():
325
+ torch.cuda.empty_cache()
326
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
327
+ free_mem_mb = free_mem / (1024**2)
328
+ total_mem_mb = total_mem / (1024**2)
329
+ print(f"Free memory: {free_mem_mb:.2f} MB")
330
+ print(f"Total memory: {total_mem_mb:.2f} MB")
331
+ return
332
+
333
+
334
+ def print_gpu_parameters(device, state_dict, log_count=1):
335
+ summary = {"device": device, "keys_count": len(state_dict)}
336
+
337
+ logged_params = {}
338
+ for i, (key, tensor) in enumerate(state_dict.items()):
339
+ if i >= log_count:
340
+ break
341
+ logged_params[key] = tensor.flatten()[:3].tolist()
342
+
343
+ summary["params"] = logged_params
344
+
345
+ print(str(summary))
346
+ return
347
+
348
+
349
+ def visualize_txt_as_img(width, height, text, font_path="font/DejaVuSans.ttf", size=18):
350
+ from PIL import Image, ImageDraw, ImageFont
351
+
352
+ txt = Image.new("RGB", (width, height), color="white")
353
+ draw = ImageDraw.Draw(txt)
354
+ font = ImageFont.truetype(font_path, size=size)
355
+
356
+ if text == "":
357
+ return np.array(txt)
358
+
359
+ # Split text into lines that fit within the image width
360
+ lines = []
361
+ words = text.split()
362
+ current_line = words[0]
363
+
364
+ for word in words[1:]:
365
+ line_with_word = f"{current_line} {word}"
366
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
367
+ current_line = line_with_word
368
+ else:
369
+ lines.append(current_line)
370
+ current_line = word
371
+
372
+ lines.append(current_line)
373
+
374
+ # Draw the text line by line
375
+ y = 0
376
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
377
+
378
+ for line in lines:
379
+ if y + line_height > height:
380
+ break # stop drawing if the next line will be outside the image
381
+ draw.text((0, y), line, fill="black", font=font)
382
+ y += line_height
383
+
384
+ return np.array(txt)
385
+
386
+
387
+ def blue_mark(x):
388
+ x = x.copy()
389
+ c = x[:, :, 2]
390
+ b = cv2.blur(c, (9, 9))
391
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
392
+ return x
393
+
394
+
395
+ def green_mark(x):
396
+ x = x.copy()
397
+ x[:, :, 2] = -1
398
+ x[:, :, 0] = -1
399
+ return x
400
+
401
+
402
+ def frame_mark(x):
403
+ x = x.copy()
404
+ x[:64] = -1
405
+ x[-64:] = -1
406
+ x[:, :8] = 1
407
+ x[:, -8:] = 1
408
+ return x
409
+
410
+
411
+ @torch.inference_mode()
412
+ def pytorch2numpy(imgs):
413
+ results = []
414
+ for x in imgs:
415
+ y = x.movedim(0, -1)
416
+ y = y * 127.5 + 127.5
417
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
418
+ results.append(y)
419
+ return results
420
+
421
+
422
+ @torch.inference_mode()
423
+ def numpy2pytorch(imgs):
424
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
425
+ h = h.movedim(-1, 1)
426
+ return h
427
+
428
+
429
+ @torch.no_grad()
430
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
431
+ if zero_out:
432
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
433
+ else:
434
+ return torch.cat([x, x[:count]], dim=0)
435
+
436
+
437
+ def weighted_mse(a, b, weight):
438
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
439
+
440
+
441
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
442
+ x = (x - x_min) / (x_max - x_min)
443
+ x = max(0.0, min(x, 1.0))
444
+ x = x**sigma
445
+ return y_min + x * (y_max - y_min)
446
+
447
+
448
+ def expand_to_dims(x, target_dims):
449
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
450
+
451
+
452
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
453
+ if tensor is None:
454
+ return None
455
+
456
+ first_dim = tensor.shape[0]
457
+
458
+ if first_dim == batch_size:
459
+ return tensor
460
+
461
+ if batch_size % first_dim != 0:
462
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
463
+
464
+ repeat_times = batch_size // first_dim
465
+
466
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
467
+
468
+
469
+ def dim5(x):
470
+ return expand_to_dims(x, 5)
471
+
472
+
473
+ def dim4(x):
474
+ return expand_to_dims(x, 4)
475
+
476
+
477
+ def dim3(x):
478
+ return expand_to_dims(x, 3)
479
+
480
+
481
+ def crop_or_pad_yield_mask(x, length):
482
+ B, F, C = x.shape
483
+ device = x.device
484
+ dtype = x.dtype
485
+
486
+ if F < length:
487
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
488
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
489
+ y[:, :F, :] = x
490
+ mask[:, :F] = True
491
+ return y, mask
492
+
493
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
494
+
495
+
496
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
497
+ original_length = int(x.shape[dim])
498
+
499
+ if original_length >= minimal_length:
500
+ return x
501
+
502
+ if zero_pad:
503
+ padding_shape = list(x.shape)
504
+ padding_shape[dim] = minimal_length - original_length
505
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
506
+ else:
507
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
508
+ last_element = x[idx]
509
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
510
+
511
+ return torch.cat([x, padding], dim=dim)
512
+
513
+
514
+ def lazy_positional_encoding(t, repeats=None):
515
+ if not isinstance(t, list):
516
+ t = [t]
517
+
518
+ from diffusers.models.embeddings import get_timestep_embedding
519
+
520
+ te = torch.tensor(t)
521
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
522
+
523
+ if repeats is None:
524
+ return te
525
+
526
+ te = te[:, None, :].expand(-1, repeats, -1)
527
+
528
+ return te
529
+
530
+
531
+ def state_dict_offset_merge(A, B, C=None):
532
+ result = {}
533
+ keys = A.keys()
534
+
535
+ for key in keys:
536
+ A_value = A[key]
537
+ B_value = B[key].to(A_value)
538
+
539
+ if C is None:
540
+ result[key] = A_value + B_value
541
+ else:
542
+ C_value = C[key].to(A_value)
543
+ result[key] = A_value + B_value - C_value
544
+
545
+ return result
546
+
547
+
548
+ def state_dict_weighted_merge(state_dicts, weights):
549
+ if len(state_dicts) != len(weights):
550
+ raise ValueError("Number of state dictionaries must match number of weights")
551
+
552
+ if not state_dicts:
553
+ return {}
554
+
555
+ total_weight = sum(weights)
556
+
557
+ if total_weight == 0:
558
+ raise ValueError("Sum of weights cannot be zero")
559
+
560
+ normalized_weights = [w / total_weight for w in weights]
561
+
562
+ keys = state_dicts[0].keys()
563
+ result = {}
564
+
565
+ for key in keys:
566
+ result[key] = state_dicts[0][key] * normalized_weights[0]
567
+
568
+ for i in range(1, len(state_dicts)):
569
+ state_dict_value = state_dicts[i][key].to(result[key])
570
+ result[key] += state_dict_value * normalized_weights[i]
571
+
572
+ return result
573
+
574
+
575
+ def group_files_by_folder(all_files):
576
+ grouped_files = {}
577
+
578
+ for file in all_files:
579
+ folder_name = os.path.basename(os.path.dirname(file))
580
+ if folder_name not in grouped_files:
581
+ grouped_files[folder_name] = []
582
+ grouped_files[folder_name].append(file)
583
+
584
+ list_of_lists = list(grouped_files.values())
585
+ return list_of_lists
586
+
587
+
588
+ def generate_timestamp():
589
+ now = datetime.datetime.now()
590
+ timestamp = now.strftime("%y%m%d_%H%M%S")
591
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
592
+ random_number = random.randint(0, 9999)
593
+ return f"{timestamp}_{milliseconds}_{random_number}"
594
+
595
+
596
+ def write_PIL_image_with_png_info(image, metadata, path):
597
+ from PIL.PngImagePlugin import PngInfo
598
+
599
+ png_info = PngInfo()
600
+ for key, value in metadata.items():
601
+ png_info.add_text(key, value)
602
+
603
+ image.save(path, "PNG", pnginfo=png_info)
604
+ return image
605
+
606
+
607
+ def torch_safe_save(content, path):
608
+ torch.save(content, path + "_tmp")
609
+ os.replace(path + "_tmp", path)
610
+ return path
611
+
612
+
613
+ def move_optimizer_to_device(optimizer, device):
614
+ for state in optimizer.state.values():
615
+ for k, v in state.items():
616
+ if isinstance(v, torch.Tensor):
617
+ state[k] = v.to(device)
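Among the helpers above, soft_append_bcthw is the one that stitches consecutive latent chunks together with a linear cross-fade over the overlapping frames; a small shape check (dimensions chosen arbitrarily):

    # Shape check for soft_append_bcthw: 5 + 5 frames with a 2-frame cross-fade -> 8 frames.
    import torch
    from frame_pack.utils import soft_append_bcthw

    history = torch.randn(1, 16, 5, 8, 8)   # B, C, T, H, W latents
    current = torch.randn(1, 16, 5, 8, 8)
    merged = soft_append_bcthw(history, current, overlap=2)
    print(merged.shape)                     # torch.Size([1, 16, 8, 8, 8])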
frame_pack/wrapper.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+
4
+ def append_dims(x, target_dims):
5
+ return x[(...,) + (None,) * (target_dims - x.ndim)]
6
+
7
+
8
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
9
+ if guidance_rescale == 0:
10
+ return noise_cfg
11
+
12
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
15
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
16
+ return noise_cfg
17
+
18
+
19
+ def fm_wrapper(transformer, t_scale=1000.0):
20
+ def k_model(x, sigma, **extra_args):
21
+ dtype = extra_args['dtype']
22
+ cfg_scale = extra_args['cfg_scale']
23
+ cfg_rescale = extra_args['cfg_rescale']
24
+ concat_latent = extra_args['concat_latent']
25
+
26
+ original_dtype = x.dtype
27
+ sigma = sigma.float()
28
+
29
+ x = x.to(dtype)
30
+ timestep = (sigma * t_scale).to(dtype)
31
+
32
+ if concat_latent is None:
33
+ hidden_states = x
34
+ else:
35
+ hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
36
+
37
+ pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
38
+
39
+ if cfg_scale == 1.0:
40
+ pred_negative = torch.zeros_like(pred_positive)
41
+ else:
42
+ pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
43
+
44
+ pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
45
+ pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
46
+
47
+ x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
48
+
49
+ return x0.to(dtype=original_dtype)
50
+
51
+ return k_model
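fm_wrapper turns the transformer's prediction into an x0 estimate via x0 = x - sigma * pred, which matches how sample_hunyuan mixes initial_latent and noise as (1 - sigma) * x0 + sigma * noise; a quick numeric check of that identity (shapes are arbitrary):

    # Numeric check of the x0 = x_t - sigma * v relation used in fm_wrapper's k_model.
    import torch
    from frame_pack.wrapper import append_dims

    x0 = torch.randn(2, 4, 8, 8)
    noise = torch.randn_like(x0)
    sigma = torch.tensor([0.7, 0.3])
    s = append_dims(sigma, x0.ndim)          # broadcastable shape (2, 1, 1, 1)
    x_t = (1 - s) * x0 + s * noise           # flow-matching interpolation
    v = noise - x0                           # velocity-style prediction assumed by the wrapper
    print(torch.allclose(x_t - s * v, x0, atol=1e-6))   # True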
hunyuan_model/fp8_optimization.py ADDED
@@ -0,0 +1,39 @@
1
+ #based on ComfyUI's and MinusZoneAI's fp8_linear optimization
2
+ #further borrowed from HunyuanVideoWrapper for Musubi Tuner
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def fp8_linear_forward(cls, original_dtype, input):
7
+ weight_dtype = cls.weight.dtype
8
+ if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
9
+ if len(input.shape) == 3:
10
+ target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
11
+ inn = input.reshape(-1, input.shape[2]).to(target_dtype)
12
+ w = cls.weight.t()
13
+
14
+ scale = torch.ones((1), device=input.device, dtype=torch.float32)
15
+ bias = cls.bias.to(original_dtype) if cls.bias is not None else None
16
+
17
+ if bias is not None:
18
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
19
+ else:
20
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
21
+
22
+ if isinstance(o, tuple):
23
+ o = o[0]
24
+
25
+ return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
26
+ else:
27
+ return cls.original_forward(input.to(original_dtype))
28
+ else:
29
+ return cls.original_forward(input)
30
+
31
+ def convert_fp8_linear(module, original_dtype, params_to_keep={}):
32
+ setattr(module, "fp8_matmul_enabled", True)
33
+
34
+ for name, module in module.named_modules():
35
+ if not any(keyword in name for keyword in params_to_keep):
36
+ if isinstance(module, nn.Linear):
37
+ original_forward = module.forward
38
+ setattr(module, "original_forward", original_forward)
39
+ setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
hv_generate_video.py CHANGED
@@ -25,6 +25,7 @@ from hunyuan_model.text_encoder import TextEncoder
25
  from hunyuan_model.text_encoder import PROMPT_TEMPLATE
26
  from hunyuan_model.vae import load_vae
27
  from hunyuan_model.models import load_transformer, get_rotary_pos_embed
 
28
  from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
29
  from networks import lora
30
 
@@ -313,23 +314,6 @@ def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=Fal
313
  # endregion
314
 
315
 
316
- def load_images(image_dir, video_length, bucket_reso):
317
- image_files = glob_images(image_dir)
318
- if len(image_files) == 0:
319
- raise ValueError(f"No image files found in {image_dir}")
320
- if len(image_files) < video_length:
321
- raise ValueError(f"Number of images in {image_dir} is less than {video_length}")
322
-
323
- image_files.sort()
324
- images = []
325
- for image_file in image_files[:video_length]:
326
- image = Image.open(image_file)
327
- image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array
328
- images.append(image)
329
-
330
- return images
331
-
332
-
333
  def prepare_vae(args, device):
334
  vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
335
  vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
@@ -479,6 +463,15 @@ def parse_args():
479
  parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
480
  parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
481
  parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
482
 
483
  args = parser.parse_args()
484
 
@@ -488,6 +481,9 @@ def parse_args():
488
 
489
  # update dit_weight based on model_base if not exists
490
 
 
 
 
491
  return args
492
 
493
 
@@ -573,12 +569,7 @@ def main():
573
  if args.video_path is not None:
574
  # v2v inference
575
  logger.info(f"Video2Video inference: {args.video_path}")
576
-
577
- if os.path.isfile(args.video_path):
578
- video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
579
- else:
580
- video = load_images(args.video_path, video_length, bucket_reso=(width, height)) # list of frames
581
-
582
  if len(video) < video_length:
583
  raise ValueError(f"Video length is less than {video_length}")
584
  video = np.stack(video, axis=0) # F, H, W, C
@@ -682,16 +673,50 @@ def main():
682
  logger.info("Merged model saved")
683
  return
684
 
685
  if blocks_to_swap > 0:
686
- logger.info(f"Casting model to {dit_weight_dtype}")
687
- transformer.to(dtype=dit_weight_dtype)
688
  logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
689
  transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
690
  transformer.move_to_device_except_swap_blocks(device)
691
  transformer.prepare_block_swap_before_forward()
692
  else:
693
- logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}")
694
- transformer.to(device=device, dtype=dit_weight_dtype)
695
  if args.img_in_txt_in_offloading:
696
  logger.info("Enable offloading img_in and txt_in to CPU")
697
  transformer.enable_img_in_txt_in_offloading()
 
25
  from hunyuan_model.text_encoder import PROMPT_TEMPLATE
26
  from hunyuan_model.vae import load_vae
27
  from hunyuan_model.models import load_transformer, get_rotary_pos_embed
28
+ from hunyuan_model.fp8_optimization import convert_fp8_linear
29
  from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
30
  from networks import lora
31
 
 
314
  # endregion
315
 
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  def prepare_vae(args, device):
318
  vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
319
  vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
 
463
  parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
464
  parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
465
  parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
466
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+)")
467
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
468
+ parser.add_argument(
469
+ "--compile_args",
470
+ nargs=4,
471
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
472
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
473
+ help="Torch.compile settings",
474
+ )
475
 
476
  args = parser.parse_args()
477
 
 
481
 
482
  # update dit_weight based on model_base if not exists
483
 
484
+ if args.fp8_fast and not args.fp8:
485
+ raise ValueError("--fp8_fast requires --fp8")
486
+
487
  return args
488
 
489
 
 
569
  if args.video_path is not None:
570
  # v2v inference
571
  logger.info(f"Video2Video inference: {args.video_path}")
572
+ video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
 
 
 
 
 
573
  if len(video) < video_length:
574
  raise ValueError(f"Video length is less than {video_length}")
575
  video = np.stack(video, axis=0) # F, H, W, C
 
673
  logger.info("Merged model saved")
674
  return
675
 
676
+ logger.info(f"Casting model to {dit_weight_dtype}")
677
+ transformer.to(dtype=dit_weight_dtype)
678
+
679
+ if args.fp8_fast:
680
+ logger.info("Enabling FP8 acceleration")
681
+ params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
682
+ for name, param in transformer.named_parameters():
683
+ dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
684
+ param.to(dtype=dtype_to_use)
685
+ convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)
686
+
687
+ if args.compile:
688
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
689
+ logger.info(
690
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
691
+ )
692
+ torch._dynamo.config.cache_size_limit = 32
693
+ for i, block in enumerate(transformer.single_blocks):
694
+ compiled_block = torch.compile(
695
+ block,
696
+ backend=compile_backend,
697
+ mode=compile_mode,
698
+ dynamic=compile_dynamic.lower() in "true",
699
+ fullgraph=compile_fullgraph.lower() in "true",
700
+ )
701
+ transformer.single_blocks[i] = compiled_block
702
+ for i, block in enumerate(transformer.double_blocks):
703
+ compiled_block = torch.compile(
704
+ block,
705
+ backend=compile_backend,
706
+ mode=compile_mode,
707
+ dynamic=compile_dynamic.lower() in "true",
708
+ fullgraph=compile_fullgraph.lower() in "true",
709
+ )
710
+ transformer.double_blocks[i] = compiled_block
711
+
712
  if blocks_to_swap > 0:
 
 
713
  logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
714
  transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
715
  transformer.move_to_device_except_swap_blocks(device)
716
  transformer.prepare_block_swap_before_forward()
717
  else:
718
+ logger.info(f"Moving model to {device}")
719
+ transformer.to(device=device)
720
  if args.img_in_txt_in_offloading:
721
  logger.info("Enable offloading img_in and txt_in to CPU")
722
  transformer.enable_img_in_txt_in_offloading()
hv_train_network.py CHANGED
@@ -24,7 +24,7 @@ import toml
24
 
25
  import torch
26
  from tqdm import tqdm
27
- from accelerate.utils import set_seed
28
  from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
29
  from safetensors.torch import load_file
30
  import transformers
@@ -159,11 +159,21 @@ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
159
  ]
160
  kwargs_handlers = [i for i in kwargs_handlers if i is not None]
161
 
162
  accelerator = Accelerator(
163
  gradient_accumulation_steps=args.gradient_accumulation_steps,
164
  mixed_precision=args.mixed_precision,
165
  log_with=log_with,
166
  project_dir=logging_dir,
 
167
  kwargs_handlers=kwargs_handlers,
168
  )
169
  print("accelerator device:", accelerator.device)
@@ -228,6 +238,25 @@ def line_to_prompt_dict(line: str) -> dict:
228
  prompt_dict["image_path"] = m.group(1)
229
  continue
230
 
231
  except ValueError as ex:
232
  logger.error(f"Exception in parsing / 解析エラー: {parg}")
233
  logger.error(ex)
@@ -340,8 +369,7 @@ def should_sample_images(args, steps, epoch=None):
340
 
341
  class NetworkTrainer:
342
  def __init__(self):
343
- self._i2v_training = False
344
- self.pos_embed_cache = {}
345
 
346
  # TODO 他のスクリプトと共通化する
347
  def generate_step_logs(
@@ -872,7 +900,7 @@ class NetworkTrainer:
872
  transformer.switch_block_swap_for_inference()
873
 
874
  # Create a directory to save the samples
875
- save_dir = args.output_dir + "/sample"
876
  os.makedirs(save_dir, exist_ok=True)
877
 
878
  # save random state to restore later
@@ -919,13 +947,15 @@ class NetworkTrainer:
919
  width = sample_parameter.get("width", 256) # make smaller for faster and memory saving inference
920
  height = sample_parameter.get("height", 256)
921
  frame_count = sample_parameter.get("frame_count", 1)
922
- guidance_scale = sample_parameter.get("guidance_scale", 6.0)
923
  discrete_flow_shift = sample_parameter.get("discrete_flow_shift", 14.5)
924
  seed = sample_parameter.get("seed")
925
  prompt: str = sample_parameter.get("prompt", "")
926
  cfg_scale = sample_parameter.get("cfg_scale", None) # None for architecture default
927
  negative_prompt = sample_parameter.get("negative_prompt", None)
928
 
 
 
929
  if self.i2v_training:
930
  image_path = sample_parameter.get("image_path", None)
931
  if image_path is None:
@@ -934,6 +964,16 @@ class NetworkTrainer:
934
  else:
935
  image_path = None
936
 
937
  device = accelerator.device
938
  if seed is not None:
939
  torch.manual_seed(seed)
@@ -963,6 +1003,8 @@ class NetworkTrainer:
963
 
964
  if self.i2v_training:
965
  logger.info(f"image path: {image_path}")
 
 
966
 
967
  # inference: architecture dependent
968
  video = self.do_inference(
@@ -982,9 +1024,14 @@ class NetworkTrainer:
982
  guidance_scale,
983
  cfg_scale,
984
  image_path=image_path,
 
985
  )
986
 
987
  # Save video
 
 
 
 
988
  ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
989
  num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
990
  seed_suffix = "" if seed is None else f"_{seed}"
@@ -1011,15 +1058,25 @@ class NetworkTrainer:
1011
  def architecture_full_name(self) -> str:
1012
  return ARCHITECTURE_HUNYUAN_VIDEO_FULL
1013
 
1014
- def assert_model_specific_args(self, args: argparse.Namespace):
 
 
1015
  self._i2v_training = args.dit_in_channels == 32 # may be changed in the future
1016
  if self._i2v_training:
1017
  logger.info("I2V training mode")
1018
 
 
 
 
 
1019
  @property
1020
  def i2v_training(self) -> bool:
1021
  return self._i2v_training
1022
 
 
 
 
 
1023
  def process_sample_prompts(
1024
  self,
1025
  args: argparse.Namespace,
@@ -1108,6 +1165,7 @@ class NetworkTrainer:
1108
  guidance_scale,
1109
  cfg_scale,
1110
  image_path=None,
 
1111
  ):
1112
  """architecture dependent inference"""
1113
  device = accelerator.device
@@ -1260,12 +1318,13 @@ class NetworkTrainer:
1260
 
1261
  def load_transformer(
1262
  self,
 
1263
  args: argparse.Namespace,
1264
  dit_path: str,
1265
  attn_mode: str,
1266
  split_attn: bool,
1267
  loading_device: str,
1268
- dit_weight_dtype: torch.dtype,
1269
  ):
1270
  transformer = load_transformer(dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.dit_in_channels)
1271
 
@@ -1346,9 +1405,16 @@ class NetworkTrainer:
1346
  raise ValueError("dataset_config is required / dataset_configが必要です")
1347
  if args.dit is None:
1348
  raise ValueError("path to DiT model is required / DiTモデルのパスが必要です")
1349
 
1350
  # check model specific arguments
1351
- self.assert_model_specific_args(args)
1352
 
1353
  # show timesteps for debugging
1354
  if args.show_timesteps:
@@ -1389,7 +1455,7 @@ class NetworkTrainer:
1389
 
1390
  # HunyuanVideo: bfloat16 or float16, Wan2.1: bfloat16
1391
  dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype)
1392
- dit_weight_dtype = torch.float8_e4m3fn if args.fp8_base else dit_dtype
1393
  logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
1394
 
1395
  # get embedding for sampling images
@@ -1406,6 +1472,7 @@ class NetworkTrainer:
1406
 
1407
  # load DiT model
1408
  blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
 
1409
  loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
1410
 
1411
  logger.info(f"Loading DiT model from {args.dit}")
@@ -1423,7 +1490,9 @@ class NetworkTrainer:
1423
  raise ValueError(
1424
  f"either --sdpa, --flash-attn, --flash3, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --flash3, --sage-attn, --xformersのいずれかを指定してください"
1425
  )
1426
- transformer = self.load_transformer(args, args.dit, attn_mode, args.split_attn, loading_device, dit_weight_dtype)
 
 
1427
  transformer.eval()
1428
  transformer.requires_grad_(False)
1429
 
@@ -1565,7 +1634,7 @@ class NetworkTrainer:
1565
  network_dtype = weight_dtype
1566
  network.to(network_dtype)
1567
 
1568
- if dit_weight_dtype != dit_dtype:
1569
  logger.info(f"casting model to {dit_weight_dtype}")
1570
  transformer.to(dit_weight_dtype)
1571
 
@@ -2239,6 +2308,34 @@ def setup_parser_common() -> argparse.ArgumentParser:
2239
  # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
2240
  # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
2241
 
2242
  parser.add_argument(
2243
  "--blocks_to_swap",
2244
  type=int,
@@ -2590,5 +2687,7 @@ if __name__ == "__main__":
2590
  args = parser.parse_args()
2591
  args = read_config_from_file(args, parser)
2592
 
 
 
2593
  trainer = NetworkTrainer()
2594
  trainer.train(args)
 
24
 
25
  import torch
26
  from tqdm import tqdm
27
+ from accelerate.utils import TorchDynamoPlugin, set_seed, DynamoBackend
28
  from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
29
  from safetensors.torch import load_file
30
  import transformers
 
159
  ]
160
  kwargs_handlers = [i for i in kwargs_handlers if i is not None]
161
 
162
+ dynamo_plugin = None
163
+ if args.dynamo_backend.upper() != "NO":
164
+ dynamo_plugin = TorchDynamoPlugin(
165
+ backend=DynamoBackend(args.dynamo_backend.upper()),
166
+ mode=args.dynamo_mode,
167
+ fullgraph=args.dynamo_fullgraph,
168
+ dynamic=args.dynamo_dynamic,
169
+ )
170
+
171
  accelerator = Accelerator(
172
  gradient_accumulation_steps=args.gradient_accumulation_steps,
173
  mixed_precision=args.mixed_precision,
174
  log_with=log_with,
175
  project_dir=logging_dir,
176
+ dynamo_plugin=dynamo_plugin,
177
  kwargs_handlers=kwargs_handlers,
178
  )
179
  print("accelerator device:", accelerator.device)
 
238
  prompt_dict["image_path"] = m.group(1)
239
  continue
240
 
241
+ m = re.match(r"cn (.+)", parg, re.IGNORECASE)
242
+ if m:
243
+ prompt_dict["control_video_path"] = m.group(1)
244
+ continue
245
+
246
+ m = re.match(r"ci (.+)", parg, re.IGNORECASE)
247
+ if m:
248
+ # can be multiple control images
249
+ control_image_path = m.group(1)
250
+ if "control_image_path" not in prompt_dict:
251
+ prompt_dict["control_image_path"] = []
252
+ prompt_dict["control_image_path"].append(control_image_path)
253
+ continue
254
+
255
+ m = re.match(r"of (.+)", parg, re.IGNORECASE)
256
+ if m:  # one frame inference options
257
+ prompt_dict["one_frame"] = m.group(1)
258
+ continue
259
+
260
  except ValueError as ex:
261
  logger.error(f"Exception in parsing / 解析エラー: {parg}")
262
  logger.error(ex)
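The `cn`, `ci` and `of` matches above extend the sample-prompt syntax with control inputs. A small sketch of how such a line would be parsed, assuming the same `--`-separated option format used by the existing options; the paths are placeholders:

```python
import re

# Hypothetical prompt line: "--cn" gives a control video, "--ci" may repeat for several control images.
line = "a character waves --w 256 --h 256 --f 25 --cn control/pose.mp4 --ci ref/a.png --ci ref/b.png"

prompt_dict = {}
for parg in [p.strip() for p in line.split("--")[1:]]:
    if m := re.match(r"cn (.+)", parg, re.IGNORECASE):
        prompt_dict["control_video_path"] = m.group(1)
    elif m := re.match(r"ci (.+)", parg, re.IGNORECASE):
        prompt_dict.setdefault("control_image_path", []).append(m.group(1))

print(prompt_dict)
# {'control_video_path': 'control/pose.mp4', 'control_image_path': ['ref/a.png', 'ref/b.png']}
```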
 
369
 
370
  class NetworkTrainer:
371
  def __init__(self):
372
+ self.blocks_to_swap = None
 
373
 
374
  # TODO 他のスクリプトと共通化する
375
  def generate_step_logs(
 
900
  transformer.switch_block_swap_for_inference()
901
 
902
  # Create a directory to save the samples
903
+ save_dir = os.path.join(args.output_dir, "sample")
904
  os.makedirs(save_dir, exist_ok=True)
905
 
906
  # save random state to restore later
 
947
  width = sample_parameter.get("width", 256) # make smaller for faster and memory saving inference
948
  height = sample_parameter.get("height", 256)
949
  frame_count = sample_parameter.get("frame_count", 1)
950
+ guidance_scale = sample_parameter.get("guidance_scale", self.default_guidance_scale)
951
  discrete_flow_shift = sample_parameter.get("discrete_flow_shift", 14.5)
952
  seed = sample_parameter.get("seed")
953
  prompt: str = sample_parameter.get("prompt", "")
954
  cfg_scale = sample_parameter.get("cfg_scale", None) # None for architecture default
955
  negative_prompt = sample_parameter.get("negative_prompt", None)
956
 
957
+ frame_count = (frame_count - 1) // 4 * 4 + 1 # 1, 5, 9, 13, ... For HunyuanVideo and Wan2.1
958
+
959
  if self.i2v_training:
960
  image_path = sample_parameter.get("image_path", None)
961
  if image_path is None:
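The frame-count snapping in the hunk above rounds any requested `frame_count` down to the nearest value of the form 4k+1, which is what the temporal compression of the video VAEs expects. A quick check of the formula:

```python
snap = lambda f: (f - 1) // 4 * 4 + 1  # same expression as in the hunk above
print([(f, snap(f)) for f in (1, 2, 5, 8, 9, 16, 25)])
# [(1, 1), (2, 1), (5, 5), (8, 5), (9, 9), (16, 13), (25, 25)]
```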
 
964
  else:
965
  image_path = None
966
 
967
+ if self.control_training:
968
+ control_video_path = sample_parameter.get("control_video_path", None)
969
+ if control_video_path is None:
970
+ logger.error(
971
+ "No control_video_path for control model / controlモデルのサンプル画像生成にはcontrol_video_pathが必要です"
972
+ )
973
+ return
974
+ else:
975
+ control_video_path = None
976
+
977
  device = accelerator.device
978
  if seed is not None:
979
  torch.manual_seed(seed)
 
1003
 
1004
  if self.i2v_training:
1005
  logger.info(f"image path: {image_path}")
1006
+ if self.control_training:
1007
+ logger.info(f"control video path: {control_video_path}")
1008
 
1009
  # inference: architecture dependent
1010
  video = self.do_inference(
 
1024
  guidance_scale,
1025
  cfg_scale,
1026
  image_path=image_path,
1027
+ control_video_path=control_video_path,
1028
  )
1029
 
1030
  # Save video
1031
+ if video is None:
1032
+ logger.error("No video generated / 生成された動画がありません")
1033
+ return
1034
+
1035
  ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
1036
  num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
1037
  seed_suffix = "" if seed is None else f"_{seed}"
 
1058
  def architecture_full_name(self) -> str:
1059
  return ARCHITECTURE_HUNYUAN_VIDEO_FULL
1060
 
1061
+ def handle_model_specific_args(self, args: argparse.Namespace):
1062
+ self.pos_embed_cache = {}
1063
+
1064
  self._i2v_training = args.dit_in_channels == 32 # may be changed in the future
1065
  if self._i2v_training:
1066
  logger.info("I2V training mode")
1067
 
1068
+ self._control_training = False # HunyuanVideo does not support control training yet
1069
+
1070
+ self.default_guidance_scale = 6.0
1071
+
1072
  @property
1073
  def i2v_training(self) -> bool:
1074
  return self._i2v_training
1075
 
1076
+ @property
1077
+ def control_training(self) -> bool:
1078
+ return self._control_training
1079
+
1080
  def process_sample_prompts(
1081
  self,
1082
  args: argparse.Namespace,
 
1165
  guidance_scale,
1166
  cfg_scale,
1167
  image_path=None,
1168
+ control_video_path=None,
1169
  ):
1170
  """architecture dependent inference"""
1171
  device = accelerator.device
 
1318
 
1319
  def load_transformer(
1320
  self,
1321
+ accelerator: Accelerator,
1322
  args: argparse.Namespace,
1323
  dit_path: str,
1324
  attn_mode: str,
1325
  split_attn: bool,
1326
  loading_device: str,
1327
+ dit_weight_dtype: Optional[torch.dtype],
1328
  ):
1329
  transformer = load_transformer(dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.dit_in_channels)
1330
 
 
1405
  raise ValueError("dataset_config is required / dataset_configが必要です")
1406
  if args.dit is None:
1407
  raise ValueError("path to DiT model is required / DiTモデルのパスが必要です")
1408
+ assert not args.fp8_scaled or args.fp8_base, "fp8_scaled requires fp8_base / fp8_scaledはfp8_baseが必要です"
1409
+
1410
+ if args.sage_attn:
1411
+ raise ValueError(
1412
+ "SageAttention doesn't support training currently. Please use `--sdpa` or `--xformers` etc. instead."
1413
+ " / SageAttentionは現在学習をサポートしていないようです。`--sdpa`や`--xformers`などの他のオプションを使ってください"
1414
+ )
1415
 
1416
  # check model specific arguments
1417
+ self.handle_model_specific_args(args)
1418
 
1419
  # show timesteps for debugging
1420
  if args.show_timesteps:
 
1455
 
1456
  # HunyuanVideo: bfloat16 or float16, Wan2.1: bfloat16
1457
  dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype)
1458
+ dit_weight_dtype = (None if args.fp8_scaled else torch.float8_e4m3fn) if args.fp8_base else dit_dtype
1459
  logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
1460
 
1461
  # get embedding for sampling images
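With `--fp8_scaled` in the mix, the weight-dtype selection above now has three outcomes. A sketch of the decision, pulling the same expression into a helper purely for illustration:

```python
import torch

def pick_weight_dtype(fp8_base: bool, fp8_scaled: bool, dit_dtype=torch.bfloat16):
    # None means "keep the checkpoint dtype and let the scaled-fp8 path quantize later"
    return (None if fp8_scaled else torch.float8_e4m3fn) if fp8_base else dit_dtype

print(pick_weight_dtype(False, False))  # torch.bfloat16
print(pick_weight_dtype(True, False))   # torch.float8_e4m3fn (plain fp8 cast)
print(pick_weight_dtype(True, True))    # None (scaled fp8 optimization)
```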
 
1472
 
1473
  # load DiT model
1474
  blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
1475
+ self.blocks_to_swap = blocks_to_swap
1476
  loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
1477
 
1478
  logger.info(f"Loading DiT model from {args.dit}")
 
1490
  raise ValueError(
1491
  f"either --sdpa, --flash-attn, --flash3, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --flash3, --sage-attn, --xformersのいずれかを指定してください"
1492
  )
1493
+ transformer = self.load_transformer(
1494
+ accelerator, args, args.dit, attn_mode, args.split_attn, loading_device, dit_weight_dtype
1495
+ )
1496
  transformer.eval()
1497
  transformer.requires_grad_(False)
1498
 
 
1634
  network_dtype = weight_dtype
1635
  network.to(network_dtype)
1636
 
1637
+ if dit_weight_dtype != dit_dtype and dit_weight_dtype is not None:
1638
  logger.info(f"casting model to {dit_weight_dtype}")
1639
  transformer.to(dit_weight_dtype)
1640
 
 
2308
  # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
2309
  # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
2310
 
2311
+ parser.add_argument(
2312
+ "--dynamo_backend",
2313
+ type=str,
2314
+ default="NO",
2315
+ choices=[e.value for e in DynamoBackend],
2316
+ help="dynamo backend type (default is None) / dynamoのbackendの種類(デフォルトは None)",
2317
+ )
2318
+
2319
+ parser.add_argument(
2320
+ "--dynamo_mode",
2321
+ type=str,
2322
+ default=None,
2323
+ choices=["default", "reduce-overhead", "max-autotune"],
2324
+ help="dynamo mode (default is default) / dynamoのモード(デフォルトは default)",
2325
+ )
2326
+
2327
+ parser.add_argument(
2328
+ "--dynamo_fullgraph",
2329
+ action="store_true",
2330
+ help="use fullgraph mode for dynamo / dynamoのfullgraphモードを使う",
2331
+ )
2332
+
2333
+ parser.add_argument(
2334
+ "--dynamo_dynamic",
2335
+ action="store_true",
2336
+ help="use dynamic mode for dynamo / dynamoのdynamicモードを使う",
2337
+ )
2338
+
2339
  parser.add_argument(
2340
  "--blocks_to_swap",
2341
  type=int,
 
2687
  args = parser.parse_args()
2688
  args = read_config_from_file(args, parser)
2689
 
2690
+ args.fp8_scaled = False # HunyuanVideo does not support this yet
2691
+
2692
  trainer = NetworkTrainer()
2693
  trainer.train(args)
merge_lora.py CHANGED
@@ -45,7 +45,7 @@ def main():
45
 
46
  logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
47
  weights_sd = load_file(lora_weight)
48
- network = lora.create_network_from_weights_hunyuan_video(
49
  lora_multiplier, weights_sd, unet=transformer, for_inference=True
50
  )
51
  logger.info("Merging LoRA weights to DiT model")
 
45
 
46
  logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
47
  weights_sd = load_file(lora_weight)
48
+ network = lora.create_arch_network_from_weights(
49
  lora_multiplier, weights_sd, unet=transformer, for_inference=True
50
  )
51
  logger.info("Merging LoRA weights to DiT model")
modules/fp8_optimization_utils.py ADDED
@@ -0,0 +1,356 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import logging
6
+
7
+ from tqdm import tqdm
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
+ from utils.device_utils import clean_memory_on_device
13
+
14
+
15
+ def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
16
+ """
17
+ Calculate the maximum representable value in FP8 format.
18
+ Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).
19
+
20
+ Args:
21
+ exp_bits (int): Number of exponent bits
22
+ mantissa_bits (int): Number of mantissa bits
23
+ sign_bits (int): Number of sign bits (0 or 1)
24
+
25
+ Returns:
26
+ float: Maximum value representable in FP8 format
27
+ """
28
+ assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
29
+
30
+ # Calculate exponent bias
31
+ bias = 2 ** (exp_bits - 1) - 1
32
+
33
+ # Calculate maximum mantissa value
34
+ mantissa_max = 1.0
35
+ for i in range(mantissa_bits - 1):
36
+ mantissa_max += 2 ** -(i + 1)
37
+
38
+ # Calculate maximum value
39
+ max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
40
+
41
+ return max_value
42
+
43
+
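For the default E4M3 layout this formula reproduces the familiar 448 limit; a worked check against PyTorch's own `finfo`:

```python
import torch

exp_bits, mantissa_bits = 4, 3
bias = 2 ** (exp_bits - 1) - 1                              # 7
mantissa_max = 1.0 + 0.5 + 0.25                             # 1.75, as accumulated by the loop above
max_value = mantissa_max * 2 ** (2 ** exp_bits - 1 - bias)  # 1.75 * 2**8
print(max_value)                                            # 448.0
print(torch.finfo(torch.float8_e4m3fn).max)                 # 448.0, matches PyTorch's E4M3FN limit
```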
44
+ def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
45
+ """
46
+ Quantize a tensor to FP8 format.
47
+
48
+ Args:
49
+ tensor (torch.Tensor): Tensor to quantize
50
+ scale (float or torch.Tensor): Scale factor
51
+ exp_bits (int): Number of exponent bits
52
+ mantissa_bits (int): Number of mantissa bits
53
+ sign_bits (int): Number of sign bits
54
+
55
+ Returns:
56
+ tuple: (quantized_tensor, scale_factor)
57
+ """
58
+ # Create scaled tensor
59
+ scaled_tensor = tensor / scale
60
+
61
+ # Calculate FP8 parameters
62
+ bias = 2 ** (exp_bits - 1) - 1
63
+
64
+ if max_value is None:
65
+ # Calculate max and min values
66
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
67
+ min_value = -max_value if sign_bits > 0 else 0.0
68
+
69
+ # Clamp tensor to range
70
+ clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)
71
+
72
+ # Quantization process
73
+ abs_values = torch.abs(clamped_tensor)
74
+ nonzero_mask = abs_values > 0
75
+
76
+ # Calculate log scales (only for non-zero elements)
77
+ log_scales = torch.zeros_like(clamped_tensor)
78
+ if nonzero_mask.any():
79
+ log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()
80
+
81
+ # Limit log scales and calculate quantization factor
82
+ log_scales = torch.clamp(log_scales, min=1.0)
83
+ quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)
84
+
85
+ # Quantize and dequantize
86
+ quantized = torch.round(clamped_tensor / quant_factor) * quant_factor
87
+
88
+ return quantized, scale
89
+
90
+
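A minimal round-trip check of the quantizer, assuming the module is importable as `modules.fp8_optimization_utils` (the path used in this commit):

```python
import torch
from modules.fp8_optimization_utils import calculate_fp8_maxval, quantize_tensor_to_fp8

w = torch.randn(256, 256)
max_value = calculate_fp8_maxval()          # 448.0 for the default E4M3
scale = w.abs().max() / max_value
q, _ = quantize_tensor_to_fp8(w, scale)     # values of w / scale snapped onto the E4M3 grid
dequant = q * scale                         # approximate reconstruction of w
print((w - dequant).abs().mean())           # small quantization error
```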
91
+ def optimize_state_dict_with_fp8(
92
+ state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
93
+ ):
94
+ """
95
+ Optimize Linear layer weights in a model's state dict to FP8 format.
96
+
97
+ Args:
98
+ state_dict (dict): State dict to optimize, replaced in-place
99
+ calc_device (str): Device to quantize tensors on
100
+ target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
101
+ exclude_layer_keys (list, optional): Layer key patterns to exclude
102
+ exp_bits (int): Number of exponent bits
103
+ mantissa_bits (int): Number of mantissa bits
104
+ move_to_device (bool): Move optimized tensors to the calculating device
105
+
106
+ Returns:
107
+ dict: FP8 optimized state dict
108
+ """
109
+ if exp_bits == 4 and mantissa_bits == 3:
110
+ fp8_dtype = torch.float8_e4m3fn
111
+ elif exp_bits == 5 and mantissa_bits == 2:
112
+ fp8_dtype = torch.float8_e5m2
113
+ else:
114
+ raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
115
+
116
+ # Calculate FP8 max value
117
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
118
+ min_value = -max_value # this function supports only signed FP8
119
+
120
+ # Create optimized state dict
121
+ optimized_count = 0
122
+
123
+ # Enumerate target keys
124
+ target_state_dict_keys = []
125
+ for key in state_dict.keys():
126
+ # Check if it's a weight key and matches target patterns
127
+ is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
128
+ is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
129
+ is_target = is_target and not is_excluded
130
+
131
+ if is_target and isinstance(state_dict[key], torch.Tensor):
132
+ target_state_dict_keys.append(key)
133
+
134
+ # Process each key
135
+ for key in tqdm(target_state_dict_keys):
136
+ value = state_dict[key]
137
+
138
+ # Save original device and dtype
139
+ original_device = value.device
140
+ original_dtype = value.dtype
141
+
142
+ # Move to calculation device
143
+ if calc_device is not None:
144
+ value = value.to(calc_device)
145
+
146
+ # Calculate scale factor
147
+ scale = torch.max(torch.abs(value.flatten())) / max_value
148
+ # print(f"Optimizing {key} with scale: {scale}")
149
+
150
+ # Quantize weight to FP8
151
+ quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
152
+
153
+ # Add to state dict using original key for weight and new key for scale
154
+ fp8_key = key # Maintain original key
155
+ scale_key = key.replace(".weight", ".scale_weight")
156
+
157
+ quantized_weight = quantized_weight.to(fp8_dtype)
158
+
159
+ if not move_to_device:
160
+ quantized_weight = quantized_weight.to(original_device)
161
+
162
+ scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
163
+
164
+ state_dict[fp8_key] = quantized_weight
165
+ state_dict[scale_key] = scale_tensor
166
+
167
+ optimized_count += 1
168
+
169
+ if calc_device is not None: # optimized_count % 10 == 0 and
170
+ # free memory on calculation device
171
+ clean_memory_on_device(calc_device)
172
+
173
+ logger.info(f"Number of optimized Linear layers: {optimized_count}")
174
+ return state_dict
175
+
176
+
177
+ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
178
+ """
179
+ Patched forward method for Linear layers with FP8 weights.
180
+
181
+ Args:
182
+ self: Linear layer instance
183
+ x (torch.Tensor): Input tensor
184
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
185
+ max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
186
+
187
+ Returns:
188
+ torch.Tensor: Result of linear transformation
189
+ """
190
+ if use_scaled_mm:
191
+ input_dtype = x.dtype
192
+ original_weight_dtype = self.scale_weight.dtype
193
+ weight_dtype = self.weight.dtype
194
+ target_dtype = torch.float8_e5m2
195
+ assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
196
+ assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
197
+
198
+ if max_value is None:
199
+ # no input quantization
200
+ scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
201
+ else:
202
+ # calculate scale factor for input tensor
203
+ scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
204
+
205
+ # quantize input tensor to FP8: this seems to consume a lot of memory
206
+ x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
207
+
208
+ original_shape = x.shape
209
+ x = x.reshape(-1, x.shape[2]).to(target_dtype)
210
+
211
+ weight = self.weight.t()
212
+ scale_weight = self.scale_weight.to(torch.float32)
213
+
214
+ if self.bias is not None:
215
+ # float32 is not supported with bias in scaled_mm
216
+ o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
217
+ else:
218
+ o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
219
+
220
+ return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype)
221
+
222
+ else:
223
+ # Dequantize the weight
224
+ original_dtype = self.scale_weight.dtype
225
+ dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
226
+
227
+ # Perform linear transformation
228
+ if self.bias is not None:
229
+ output = F.linear(x, dequantized_weight, self.bias)
230
+ else:
231
+ output = F.linear(x, dequantized_weight)
232
+
233
+ return output
234
+
235
+
236
+ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
237
+ """
238
+ Apply monkey patching to a model using FP8 optimized state dict.
239
+
240
+ Args:
241
+ model (nn.Module): Model instance to patch
242
+ optimized_state_dict (dict): FP8 optimized state dict
243
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
244
+
245
+ Returns:
246
+ nn.Module: The patched model (same instance, modified in-place)
247
+ """
248
+ # # Calculate FP8 float8_e5m2 max value
249
+ # max_value = calculate_fp8_maxval(5, 2)
250
+ max_value = None # do not quantize input tensor
251
+
252
+ # Find all scale keys to identify FP8-optimized layers
253
+ scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
254
+
255
+ # Enumerate patched layers
256
+ patched_module_paths = set()
257
+ for scale_key in scale_keys:
258
+ # Extract module path from scale key (remove .scale_weight)
259
+ module_path = scale_key.rsplit(".scale_weight", 1)[0]
260
+ patched_module_paths.add(module_path)
261
+
262
+ patched_count = 0
263
+
264
+ # Apply monkey patch to each layer with FP8 weights
265
+ for name, module in model.named_modules():
266
+ # Check if this module has a corresponding scale_weight
267
+ has_scale = name in patched_module_paths
268
+
269
+ # Apply patch if it's a Linear layer with FP8 scale
270
+ if isinstance(module, nn.Linear) and has_scale:
271
+ # register the scale_weight as a buffer to load the state_dict
272
+ module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
273
+
274
+ # Create a new forward method with the patched version.
275
+ def new_forward(self, x):
276
+ return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
277
+
278
+ # Bind method to module
279
+ module.forward = new_forward.__get__(module, type(module))
280
+
281
+ patched_count += 1
282
+
283
+ logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
284
+ return model
285
+
286
+
287
+ # Example usage
288
+ def example_usage():
289
+ # Small test model
290
+ class TestModel(nn.Module):
291
+ def __init__(self):
292
+ super().__init__()
293
+ fc1 = nn.Linear(768, 3072)
294
+ act1 = nn.GELU()
295
+ fc2 = nn.Linear(3072, 768)
296
+ act2 = nn.GELU()
297
+ fc3 = nn.Linear(768, 768)
298
+
299
+ # Set layer names for testing
300
+ self.single_blocks = nn.ModuleList([fc1, act1, fc2, act2, fc3])
301
+
302
+ self.fc4 = nn.Linear(768, 128)
303
+
304
+ def forward(self, x):
305
+ for layer in self.single_blocks:
306
+ x = layer(x)
307
+ x = self.fc4(x)
308
+ return x
309
+
310
+ # Instantiate model
311
+ test_model = TestModel()
312
+ test_model.to(torch.float16) # convert to FP16 for testing
313
+
314
+ # Test input tensor
315
+ test_input = torch.randn(1, 768, dtype=torch.float16)
316
+
317
+ # Calculate output before optimization
318
+ with torch.no_grad():
319
+ original_output = test_model(test_input)
320
+ print("original output", original_output[0, :5])
321
+
322
+ # Get state dict
323
+ state_dict = test_model.state_dict()
324
+
325
+ # Apply FP8 optimization to state dict
326
+ cuda_device = torch.device("cuda")
327
+ optimized_state_dict = optimize_state_dict_with_fp8(state_dict, cuda_device, ["single_blocks"], ["2"])
328
+
329
+ # Apply monkey patching to the model
330
+ optimized_model = TestModel() # re-instantiate model
331
+ optimized_model.to(torch.float16) # convert to FP16 for testing
332
+ apply_fp8_monkey_patch(optimized_model, optimized_state_dict)
333
+
334
+ # Load optimized state dict
335
+ optimized_model.load_state_dict(optimized_state_dict, strict=True, assign=True) # assign=True to load buffer
336
+
337
+ # Calculate output after optimization
338
+ with torch.no_grad():
339
+ optimized_output = optimized_model(test_input)
340
+ print("optimized output", optimized_output[0, :5])
341
+
342
+ # Compare accuracy
343
+ error = torch.mean(torch.abs(original_output - optimized_output))
344
+ print(f"Mean absolute error: {error.item()}")
345
+
346
+ # Check memory usage
347
+ original_params = sum(p.nelement() * p.element_size() for p in test_model.parameters()) / (1024 * 1024)
348
+ print(f"Model parameter memory: {original_params:.2f} MB")
349
+ optimized_params = sum(p.nelement() * p.element_size() for p in optimized_model.parameters()) / (1024 * 1024)
350
+ print(f"Optimized model parameter memory: {optimized_params:.2f} MB")
351
+
352
+ return test_model
353
+
354
+
355
+ if __name__ == "__main__":
356
+ example_usage()
networks/lora.py CHANGED
@@ -8,7 +8,6 @@ import math
8
  import os
9
  import re
10
  from typing import Dict, List, Optional, Type, Union
11
- from diffusers import AutoencoderKL
12
  from transformers import CLIPTextModel
13
  import numpy as np
14
  import torch
 
8
  import os
9
  import re
10
  from typing import Dict, List, Optional, Type, Union
 
11
  from transformers import CLIPTextModel
12
  import numpy as np
13
  import torch
networks/lora_framepack.py ADDED
@@ -0,0 +1,65 @@
1
+ # LoRA module for FramePack
2
+
3
+ import ast
4
+ from typing import Dict, List, Optional
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ import networks.lora as lora
14
+
15
+
16
+ FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"]
17
+
18
+
19
+ def create_arch_network(
20
+ multiplier: float,
21
+ network_dim: Optional[int],
22
+ network_alpha: Optional[float],
23
+ vae: nn.Module,
24
+ text_encoders: List[nn.Module],
25
+ unet: nn.Module,
26
+ neuron_dropout: Optional[float] = None,
27
+ **kwargs,
28
+ ):
29
+ # add default exclude patterns
30
+ exclude_patterns = kwargs.get("exclude_patterns", None)
31
+ if exclude_patterns is None:
32
+ exclude_patterns = []
33
+ else:
34
+ exclude_patterns = ast.literal_eval(exclude_patterns)
35
+
36
+ # exclude if 'norm' in the name of the module
37
+ exclude_patterns.append(r".*(norm).*")
38
+
39
+ kwargs["exclude_patterns"] = exclude_patterns
40
+
41
+ return lora.create_network(
42
+ FRAMEPACK_TARGET_REPLACE_MODULES,
43
+ "lora_unet",
44
+ multiplier,
45
+ network_dim,
46
+ network_alpha,
47
+ vae,
48
+ text_encoders,
49
+ unet,
50
+ neuron_dropout=neuron_dropout,
51
+ **kwargs,
52
+ )
53
+
54
+
55
+ def create_arch_network_from_weights(
56
+ multiplier: float,
57
+ weights_sd: Dict[str, torch.Tensor],
58
+ text_encoders: Optional[List[nn.Module]] = None,
59
+ unet: Optional[nn.Module] = None,
60
+ for_inference: bool = False,
61
+ **kwargs,
62
+ ) -> lora.LoRANetwork:
63
+ return lora.create_network_from_weights(
64
+ FRAMEPACK_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
65
+ )
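Loading trained weights back mirrors the `merge_lora.py` change above: the architecture-specific factory rebuilds the network from a saved state dict. A minimal sketch, assuming a FramePack DiT instance `transformer` is already loaded and the LoRA path is a placeholder; note that `exclude_patterns`, when passed to `create_arch_network`, must be a Python-literal string because it goes through `ast.literal_eval`:

```python
from safetensors.torch import load_file
import networks.lora_framepack as lora_framepack

weights_sd = load_file("framepack-lora.safetensors")  # placeholder path
network = lora_framepack.create_arch_network_from_weights(
    1.0, weights_sd, unet=transformer, for_inference=True  # `transformer` assumed to be loaded already
)
```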
pyproject.toml CHANGED
@@ -5,13 +5,15 @@ description = "Musubi Tuner by kohya_ss"
5
  readme = "README.md"
6
  requires-python = ">=3.10, <3.11"
7
  dependencies = [
8
- "accelerate>=1.0.0",
9
  "ascii-magic==2.3.0",
10
  "av==14.0.1",
11
  "bitsandbytes>=0.45.0",
12
  "diffusers>=0.32.1",
 
13
  "einops>=0.7.0",
14
- "huggingface-hub>=0.26.5",
 
15
  "matplotlib>=3.10.0",
16
  "opencv-python>=4.10.0.84",
17
  "pillow>=10.2.0",
 
5
  readme = "README.md"
6
  requires-python = ">=3.10, <3.11"
7
  dependencies = [
8
+ "accelerate>=1.6.0",
9
  "ascii-magic==2.3.0",
10
  "av==14.0.1",
11
  "bitsandbytes>=0.45.0",
12
  "diffusers>=0.32.1",
13
+ "easydict==1.13",
14
  "einops>=0.7.0",
15
+ "ftfy==6.3.1",
16
+ "huggingface-hub>=0.30.0",
17
  "matplotlib>=3.10.0",
18
  "opencv-python>=4.10.0.84",
19
  "pillow>=10.2.0",
requirements.txt CHANGED
@@ -1,11 +1,11 @@
1
- accelerate==1.2.1
2
  av==14.0.1
3
- bitsandbytes==0.45.0
4
  diffusers==0.32.1
5
  einops==0.7.0
6
- huggingface-hub==0.26.5
7
  opencv-python==4.10.0.84
8
- pillow==10.2.0
9
  safetensors==0.4.5
10
  toml==0.10.2
11
  tqdm==4.67.1
 
1
+ accelerate==1.6.0
2
  av==14.0.1
3
+ bitsandbytes==0.45.4
4
  diffusers==0.32.1
5
  einops==0.7.0
6
+ huggingface-hub==0.30.0
7
  opencv-python==4.10.0.84
8
+ pillow
9
  safetensors==0.4.5
10
  toml==0.10.2
11
  tqdm==4.67.1
utils/safetensors_utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import json
3
  import struct
@@ -169,7 +171,7 @@ class MemoryEfficientSafeOpen:
169
 
170
 
171
  def load_safetensors(
172
- path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
173
  ) -> dict[str, torch.Tensor]:
174
  if disable_mmap:
175
  # return safetensors.torch.load(open(path, "rb").read())
@@ -189,3 +191,31 @@ def load_safetensors(
189
  for key in state_dict.keys():
190
  state_dict[key] = state_dict[key].to(dtype=dtype)
191
  return state_dict
1
+ import os
2
+ import re
3
  import torch
4
  import json
5
  import struct
 
171
 
172
 
173
  def load_safetensors(
174
+ path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
175
  ) -> dict[str, torch.Tensor]:
176
  if disable_mmap:
177
  # return safetensors.torch.load(open(path, "rb").read())
 
191
  for key in state_dict.keys():
192
  state_dict[key] = state_dict[key].to(dtype=dtype)
193
  return state_dict
194
+
195
+
196
+ def load_split_weights(
197
+ file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False
198
+ ) -> Dict[str, torch.Tensor]:
199
+ """
200
+ Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix.
201
+ dtype is as is, no conversion is done.
202
+ """
203
+ device = torch.device(device)
204
+
205
+ # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
206
+ basename = os.path.basename(file_path)
207
+ match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
208
+ if match:
209
+ prefix = basename[: match.start(2)]
210
+ count = int(match.group(3))
211
+ state_dict = {}
212
+ for i in range(count):
213
+ filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
214
+ filepath = os.path.join(os.path.dirname(file_path), filename)
215
+ if os.path.exists(filepath):
216
+ state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap))
217
+ else:
218
+ raise FileNotFoundError(f"File {filepath} not found")
219
+ else:
220
+ state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap)
221
+ return state_dict
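The shard detection keys off names like `...-00001-of-00004.safetensors`; a small sketch of what the regex extracts and which sibling files get loaded, using a placeholder file name:

```python
import re

basename = "diffusion_pytorch_model-00001-of-00004.safetensors"  # placeholder shard name
m = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
prefix, count = basename[: m.start(2)], int(m.group(3))
print([f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" for i in range(count)])
# ['diffusion_pytorch_model-00001-of-00004.safetensors', ..., 'diffusion_pytorch_model-00004-of-00004.safetensors']
```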
utils/sai_model_spec.py CHANGED
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
7
  import safetensors
8
  import logging
9
 
10
- from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, ARCHITECTURE_WAN
11
 
12
  logger = logging.getLogger(__name__)
13
  logger.setLevel(logging.INFO)
@@ -59,9 +59,13 @@ ARCH_HUNYUAN_VIDEO = "hunyuan-video"
59
  # Official Wan2.1 weights does not have sai_model_spec, so we use this as an architecture name
60
  ARCH_WAN = "wan2.1"
61
 
 
 
62
  ADAPTER_LORA = "lora"
63
 
64
  IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"
 
 
65
 
66
  PRED_TYPE_EPSILON = "epsilon"
67
  # PRED_TYPE_V = "v"
@@ -121,8 +125,13 @@ def build_metadata(
121
  # arch = ARCH_HUNYUAN_VIDEO
122
  if architecture == ARCHITECTURE_HUNYUAN_VIDEO:
123
  arch = ARCH_HUNYUAN_VIDEO
 
124
  elif architecture == ARCHITECTURE_WAN:
125
  arch = ARCH_WAN
 
 
 
 
126
  else:
127
  raise ValueError(f"Unknown architecture: {architecture}")
128
 
@@ -130,7 +139,6 @@ def build_metadata(
130
  arch += f"/{ADAPTER_LORA}"
131
  metadata["modelspec.architecture"] = arch
132
 
133
- impl = IMPL_HUNYUAN_VIDEO
134
  metadata["modelspec.implementation"] = impl
135
 
136
  if title is None:
 
7
  import safetensors
8
  import logging
9
 
10
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, ARCHITECTURE_WAN, ARCHITECTURE_FRAMEPACK
11
 
12
  logger = logging.getLogger(__name__)
13
  logger.setLevel(logging.INFO)
 
59
  # Official Wan2.1 weights does not have sai_model_spec, so we use this as an architecture name
60
  ARCH_WAN = "wan2.1"
61
 
62
+ ARCH_FRAMEPACK = "framepack"
63
+
64
  ADAPTER_LORA = "lora"
65
 
66
  IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"
67
+ IMPL_WAN = "https://github.com/Wan-Video/Wan2.1"
68
+ IMPL_FRAMEPACK = "https://github.com/lllyasviel/FramePack"
69
 
70
  PRED_TYPE_EPSILON = "epsilon"
71
  # PRED_TYPE_V = "v"
 
125
  # arch = ARCH_HUNYUAN_VIDEO
126
  if architecture == ARCHITECTURE_HUNYUAN_VIDEO:
127
  arch = ARCH_HUNYUAN_VIDEO
128
+ impl = IMPL_HUNYUAN_VIDEO
129
  elif architecture == ARCHITECTURE_WAN:
130
  arch = ARCH_WAN
131
+ impl = IMPL_WAN
132
+ elif architecture == ARCHITECTURE_FRAMEPACK:
133
+ arch = ARCH_FRAMEPACK
134
+ impl = IMPL_FRAMEPACK
135
  else:
136
  raise ValueError(f"Unknown architecture: {architecture}")
137
 
 
139
  arch += f"/{ADAPTER_LORA}"
140
  metadata["modelspec.architecture"] = arch
141
 
 
142
  metadata["modelspec.implementation"] = impl
143
 
144
  if title is None:
utils/train_utils.py CHANGED
@@ -36,6 +36,7 @@ def get_sanitized_config_or_none(args: argparse.Namespace):
36
  "vae",
37
  "text_encoder1",
38
  "text_encoder2",
 
39
  "base_weights",
40
  "network_weights",
41
  "output_dir",
 
36
  "vae",
37
  "text_encoder1",
38
  "text_encoder2",
39
+ "image_encoder",
40
  "base_weights",
41
  "network_weights",
42
  "output_dir",
wan/__init__.py CHANGED
@@ -1,3 +1 @@
1
  # from . import configs, distributed, modules
2
- from .image2video import WanI2V
3
- from .text2video import WanT2V
 
1
  # from . import configs, distributed, modules
 
 
wan/configs/__init__.py CHANGED
@@ -1,8 +1,9 @@
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import copy
3
  import os
 
4
 
5
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
 
7
  from .wan_i2v_14B import i2v_14B
8
  from .wan_t2v_1_3B import t2v_1_3B
@@ -10,33 +11,59 @@ from .wan_t2v_14B import t2v_14B
10
 
11
  # the config of t2i_14B is the same as t2v_14B
12
  t2i_14B = copy.deepcopy(t2v_14B)
13
- t2i_14B.__name__ = 'Config: Wan T2I 14B'
14
 
15
  WAN_CONFIGS = {
16
- 't2v-14B': t2v_14B,
17
- 't2v-1.3B': t2v_1_3B,
18
- 'i2v-14B': i2v_14B,
19
- 't2i-14B': t2i_14B,
 
 
 
 
20
  }
21
 
22
  SIZE_CONFIGS = {
23
- '720*1280': (720, 1280),
24
- '1280*720': (1280, 720),
25
- '480*832': (480, 832),
26
- '832*480': (832, 480),
27
- '1024*1024': (1024, 1024),
28
  }
29
 
30
  MAX_AREA_CONFIGS = {
31
- '720*1280': 720 * 1280,
32
- '1280*720': 1280 * 720,
33
- '480*832': 480 * 832,
34
- '832*480': 832 * 480,
35
  }
36
 
37
  SUPPORTED_SIZES = {
38
- 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
39
- 't2v-1.3B': ('480*832', '832*480'),
40
- 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
41
- 't2i-14B': tuple(SIZE_CONFIGS.keys()),
 
 
 
 
42
  }
 
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import copy
3
  import os
4
+ import torch
5
 
6
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
 
8
  from .wan_i2v_14B import i2v_14B
9
  from .wan_t2v_1_3B import t2v_1_3B
 
11
 
12
  # the config of t2i_14B is the same as t2v_14B
13
  t2i_14B = copy.deepcopy(t2v_14B)
14
+ t2i_14B.__name__ = "Config: Wan T2I 14B"
15
+
16
+ # support Fun models: deepcopy and change some configs. FC denotes Fun Control
17
+ t2v_1_3B_FC = copy.deepcopy(t2v_1_3B)
18
+ t2v_1_3B_FC.__name__ = "Config: Wan-Fun-Control T2V 1.3B"
19
+ t2v_1_3B_FC.i2v = True # this is strange, but Fun-Control model needs this because it has img cross-attention
20
+ t2v_1_3B_FC.in_dim = 48
21
+ t2v_1_3B_FC.is_fun_control = True
22
+
23
+ t2v_14B_FC = copy.deepcopy(t2v_14B)
24
+ t2v_14B_FC.__name__ = "Config: Wan-Fun-Control T2V 14B"
25
+ t2v_14B_FC.i2v = True # this is strange, but Fun-Control model needs this because it has img cross-attention
26
+ t2v_14B_FC.in_dim = 48 # same as i2v_14B, use zeros for image latents
27
+ t2v_14B_FC.is_fun_control = True
28
+
29
+ i2v_14B_FC = copy.deepcopy(i2v_14B)
30
+ i2v_14B_FC.__name__ = "Config: Wan-Fun-Control I2V 14B"
31
+ i2v_14B_FC.in_dim = 48
32
+ i2v_14B_FC.is_fun_control = True
33
 
34
  WAN_CONFIGS = {
35
+ "t2v-14B": t2v_14B,
36
+ "t2v-1.3B": t2v_1_3B,
37
+ "i2v-14B": i2v_14B,
38
+ "t2i-14B": t2i_14B,
39
+ # Fun Control models
40
+ "t2v-1.3B-FC": t2v_1_3B_FC,
41
+ "t2v-14B-FC": t2v_14B_FC,
42
+ "i2v-14B-FC": i2v_14B_FC,
43
  }
44
 
45
  SIZE_CONFIGS = {
46
+ "720*1280": (720, 1280),
47
+ "1280*720": (1280, 720),
48
+ "480*832": (480, 832),
49
+ "832*480": (832, 480),
50
+ "1024*1024": (1024, 1024),
51
  }
52
 
53
  MAX_AREA_CONFIGS = {
54
+ "720*1280": 720 * 1280,
55
+ "1280*720": 1280 * 720,
56
+ "480*832": 480 * 832,
57
+ "832*480": 832 * 480,
58
  }
59
 
60
  SUPPORTED_SIZES = {
61
+ "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"),
62
+ "t2v-1.3B": ("480*832", "832*480"),
63
+ "i2v-14B": ("720*1280", "1280*720", "480*832", "832*480"),
64
+ "t2i-14B": tuple(SIZE_CONFIGS.keys()),
65
+ # Fun Control models
66
+ "t2v-1.3B-FC": ("480*832", "832*480"),
67
+ "t2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"),
68
+ "i2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"),
69
  }
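The Fun-Control variants are looked up with the same task-style keys as the base models; a quick sketch of what the new entries expose, assuming this package layout:

```python
from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES

cfg = WAN_CONFIGS["t2v-1.3B-FC"]
print(cfg.i2v, cfg.is_fun_control, cfg.in_dim)  # True True 48: control latents are packed into the input
print(SUPPORTED_SIZES["t2v-1.3B-FC"])           # ('480*832', '832*480')
```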
wan/configs/shared_config.py CHANGED
@@ -12,6 +12,7 @@ wan_shared_cfg.text_len = 512
12
 
13
  # transformer
14
  wan_shared_cfg.param_dtype = torch.bfloat16
 
15
 
16
  # inference
17
  wan_shared_cfg.num_train_timesteps = 1000
 
12
 
13
  # transformer
14
  wan_shared_cfg.param_dtype = torch.bfloat16
15
+ wan_shared_cfg.out_dim = 16
16
 
17
  # inference
18
  wan_shared_cfg.num_train_timesteps = 1000
wan/configs/wan_i2v_14B.py CHANGED
@@ -4,22 +4,24 @@ from easydict import EasyDict
4
 
5
  from .shared_config import wan_shared_cfg
6
 
7
- #------------------------ Wan I2V 14B ------------------------#
8
 
9
- i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10
  i2v_14B.update(wan_shared_cfg)
 
 
11
 
12
- i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- i2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
 
15
  # clip
16
- i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
17
  i2v_14B.clip_dtype = torch.float16
18
- i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
19
- i2v_14B.clip_tokenizer = 'xlm-roberta-large'
20
 
21
  # vae
22
- i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
23
  i2v_14B.vae_stride = (4, 8, 8)
24
 
25
  # transformer
@@ -27,6 +29,7 @@ i2v_14B.patch_size = (1, 2, 2)
27
  i2v_14B.dim = 5120
28
  i2v_14B.ffn_dim = 13824
29
  i2v_14B.freq_dim = 256
 
30
  i2v_14B.num_heads = 40
31
  i2v_14B.num_layers = 40
32
  i2v_14B.window_size = (-1, -1)
 
4
 
5
  from .shared_config import wan_shared_cfg
6
 
7
+ # ------------------------ Wan I2V 14B ------------------------#
8
 
9
+ i2v_14B = EasyDict(__name__="Config: Wan I2V 14B")
10
  i2v_14B.update(wan_shared_cfg)
11
+ i2v_14B.i2v = True
12
+ i2v_14B.is_fun_control = False
13
 
14
+ i2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
15
+ i2v_14B.t5_tokenizer = "google/umt5-xxl"
16
 
17
  # clip
18
+ i2v_14B.clip_model = "clip_xlm_roberta_vit_h_14"
19
  i2v_14B.clip_dtype = torch.float16
20
+ i2v_14B.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
21
+ i2v_14B.clip_tokenizer = "xlm-roberta-large"
22
 
23
  # vae
24
+ i2v_14B.vae_checkpoint = "Wan2.1_VAE.pth"
25
  i2v_14B.vae_stride = (4, 8, 8)
26
 
27
  # transformer
 
29
  i2v_14B.dim = 5120
30
  i2v_14B.ffn_dim = 13824
31
  i2v_14B.freq_dim = 256
32
+ i2v_14B.in_dim = 36
33
  i2v_14B.num_heads = 40
34
  i2v_14B.num_layers = 40
35
  i2v_14B.window_size = (-1, -1)
wan/configs/wan_t2v_14B.py CHANGED
@@ -3,17 +3,19 @@ from easydict import EasyDict
3
 
4
  from .shared_config import wan_shared_cfg
5
 
6
- #------------------------ Wan T2V 14B ------------------------#
7
 
8
- t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
9
  t2v_14B.update(wan_shared_cfg)
 
 
10
 
11
  # t5
12
- t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
 
15
  # vae
16
- t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
  t2v_14B.vae_stride = (4, 8, 8)
18
 
19
  # transformer
@@ -21,6 +23,7 @@ t2v_14B.patch_size = (1, 2, 2)
21
  t2v_14B.dim = 5120
22
  t2v_14B.ffn_dim = 13824
23
  t2v_14B.freq_dim = 256
 
24
  t2v_14B.num_heads = 40
25
  t2v_14B.num_layers = 40
26
  t2v_14B.window_size = (-1, -1)
 
3
 
4
  from .shared_config import wan_shared_cfg
5
 
6
+ # ------------------------ Wan T2V 14B ------------------------#
7
 
8
+ t2v_14B = EasyDict(__name__="Config: Wan T2V 14B")
9
  t2v_14B.update(wan_shared_cfg)
10
+ t2v_14B.i2v = False
11
+ t2v_14B.is_fun_control = False
12
 
13
  # t5
14
+ t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
15
+ t2v_14B.t5_tokenizer = "google/umt5-xxl"
16
 
17
  # vae
18
+ t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth"
19
  t2v_14B.vae_stride = (4, 8, 8)
20
 
21
  # transformer
 
23
  t2v_14B.dim = 5120
24
  t2v_14B.ffn_dim = 13824
25
  t2v_14B.freq_dim = 256
26
+ t2v_14B.in_dim = 16
27
  t2v_14B.num_heads = 40
28
  t2v_14B.num_layers = 40
29
  t2v_14B.window_size = (-1, -1)
wan/configs/wan_t2v_1_3B.py CHANGED
@@ -3,17 +3,19 @@ from easydict import EasyDict
3
 
4
  from .shared_config import wan_shared_cfg
5
 
6
- #------------------------ Wan T2V 1.3B ------------------------#
7
 
8
- t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
9
  t2v_1_3B.update(wan_shared_cfg)
 
 
10
 
11
  # t5
12
- t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14
 
15
  # vae
16
- t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
  t2v_1_3B.vae_stride = (4, 8, 8)
18
 
19
  # transformer
@@ -21,6 +23,7 @@ t2v_1_3B.patch_size = (1, 2, 2)
21
  t2v_1_3B.dim = 1536
22
  t2v_1_3B.ffn_dim = 8960
23
  t2v_1_3B.freq_dim = 256
 
24
  t2v_1_3B.num_heads = 12
25
  t2v_1_3B.num_layers = 30
26
  t2v_1_3B.window_size = (-1, -1)
 
3
 
4
  from .shared_config import wan_shared_cfg
5
 
6
+ # ------------------------ Wan T2V 1.3B ------------------------#
7
 
8
+ t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B")
9
  t2v_1_3B.update(wan_shared_cfg)
10
+ t2v_1_3B.i2v = False
11
+ t2v_1_3B.is_fun_control = False
12
 
13
  # t5
14
+ t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
15
+ t2v_1_3B.t5_tokenizer = "google/umt5-xxl"
16
 
17
  # vae
18
+ t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth"
19
  t2v_1_3B.vae_stride = (4, 8, 8)
20
 
21
  # transformer
 
23
  t2v_1_3B.dim = 1536
24
  t2v_1_3B.ffn_dim = 8960
25
  t2v_1_3B.freq_dim = 256
26
+ t2v_1_3B.in_dim = 16
27
  t2v_1_3B.num_heads = 12
28
  t2v_1_3B.num_layers = 30
29
  t2v_1_3B.window_size = (-1, -1)
wan/modules/model.py CHANGED
@@ -1,13 +1,25 @@
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import math
 
3
 
4
  import torch
5
  import torch.nn as nn
6
  from torch.utils.checkpoint import checkpoint
 
 
 
 
 
 
 
 
 
 
7
 
8
  from .attention import flash_attention
9
  from utils.device_utils import clean_memory_on_device
10
  from modules.custom_offloading_utils import ModelOffloader
 
11
 
12
  __all__ = ["WanModel"]
13
 
@@ -602,6 +614,40 @@ class WanModel(nn.Module): # ModelMixin, ConfigMixin):
602
  def device(self):
603
  return next(self.parameters()).device
604
605
  def enable_gradient_checkpointing(self):
606
  self.gradient_checkpointing = True
607
 
@@ -661,7 +707,7 @@ class WanModel(nn.Module): # ModelMixin, ConfigMixin):
661
  return
662
  self.offloader.prepare_block_devices_before_forward(self.blocks)
663
 
664
- def forward(self, x, t, context, seq_len, clip_fea=None, y=None):
665
  r"""
666
  Forward pass through the diffusion model
667
 
@@ -683,8 +729,9 @@ class WanModel(nn.Module): # ModelMixin, ConfigMixin):
683
  List[Tensor]:
684
  List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
685
  """
686
- if self.model_type == "i2v":
687
- assert clip_fea is not None and y is not None
 
688
  # params
689
  device = self.patch_embedding.weight.device
690
  if self.freqs.device != device:
@@ -738,10 +785,13 @@ class WanModel(nn.Module): # ModelMixin, ConfigMixin):
738
 
739
  # print(f"x: {x.shape}, e: {e0.shape}, context: {context.shape}, seq_lens: {seq_lens}")
740
  for block_idx, block in enumerate(self.blocks):
741
- if self.blocks_to_swap:
 
 
742
  self.offloader.wait_for_block(block_idx)
743
 
744
- x = block(x, **kwargs)
 
745
 
746
  if self.blocks_to_swap:
747
  self.offloader.submit_move_blocks_forward(self.blocks, block_idx)
@@ -801,3 +851,83 @@ class WanModel(nn.Module): # ModelMixin, ConfigMixin):
801
 
802
  # init output layer
803
  nn.init.zeros_(self.head.head.weight)
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import math
3
+ from typing import Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
  from torch.utils.checkpoint import checkpoint
8
+ from accelerate import init_empty_weights
9
+
10
+ import logging
11
+
12
+ from utils.safetensors_utils import MemoryEfficientSafeOpen, load_safetensors
13
+
14
+ logger = logging.getLogger(__name__)
15
+ logging.basicConfig(level=logging.INFO)
16
+
17
+ from utils.device_utils import clean_memory_on_device
18
 
19
  from .attention import flash_attention
20
  from utils.device_utils import clean_memory_on_device
21
  from modules.custom_offloading_utils import ModelOffloader
22
+ from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
23
 
24
  __all__ = ["WanModel"]
25
 
 
614
  def device(self):
615
  return next(self.parameters()).device
616
 
617
+ def fp8_optimization(
618
+ self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
619
+ ) -> int:
620
+ """
621
+ Optimize the model state_dict with fp8.
622
+
623
+ Args:
624
+ state_dict (dict[str, torch.Tensor]):
625
+ The state_dict of the model.
626
+ device (torch.device):
627
+ The device to calculate the weight.
628
+ move_to_device (bool):
629
+ Whether to move the weight to the device after optimization.
630
+ """
631
+ TARGET_KEYS = ["blocks"]
632
+ EXCLUDE_KEYS = [
633
+ "norm",
634
+ "patch_embedding",
635
+ "text_embedding",
636
+ "time_embedding",
637
+ "time_projection",
638
+ "head",
639
+ "modulation",
640
+ "img_emb",
641
+ ]
642
+
643
+ # inplace optimization
644
+ state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)
645
+
646
+ # apply monkey patching
647
+ apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)
648
+
649
+ return state_dict
650
+
651
  def enable_gradient_checkpointing(self):
652
  self.gradient_checkpointing = True
653
 
 
707
  return
708
  self.offloader.prepare_block_devices_before_forward(self.blocks)
709
 
710
+ def forward(self, x, t, context, seq_len, clip_fea=None, y=None, skip_block_indices=None):
711
  r"""
712
  Forward pass through the diffusion model
713
 
 
729
  List[Tensor]:
730
  List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
731
  """
732
+ # remove assertions to work with Fun-Control T2V
733
+ # if self.model_type == "i2v":
734
+ # assert clip_fea is not None and y is not None
735
  # params
736
  device = self.patch_embedding.weight.device
737
  if self.freqs.device != device:
 
785
 
786
  # print(f"x: {x.shape}, e: {e0.shape}, context: {context.shape}, seq_lens: {seq_lens}")
787
  for block_idx, block in enumerate(self.blocks):
788
+ is_block_skipped = skip_block_indices is not None and block_idx in skip_block_indices
789
+
790
+ if self.blocks_to_swap and not is_block_skipped:
791
  self.offloader.wait_for_block(block_idx)
792
 
793
+ if not is_block_skipped:
794
+ x = block(x, **kwargs)
795
 
796
  if self.blocks_to_swap:
797
  self.offloader.submit_move_blocks_forward(self.blocks, block_idx)
 
851
 
852
  # init output layer
853
  nn.init.zeros_(self.head.head.weight)
854
+
855
+
856
+ def detect_wan_sd_dtype(path: str) -> torch.dtype:
857
+ # get dtype from model weights
858
+ with MemoryEfficientSafeOpen(path) as f:
859
+ keys = set(f.keys())
860
+ key1 = "model.diffusion_model.blocks.0.cross_attn.k.weight" # 1.3B
861
+ key2 = "blocks.0.cross_attn.k.weight" # 14B
862
+ if key1 in keys:
863
+ dit_dtype = f.get_tensor(key1).dtype
864
+ elif key2 in keys:
865
+ dit_dtype = f.get_tensor(key2).dtype
866
+ else:
867
+ raise ValueError(f"Could not find the dtype in the model weights: {path}")
868
+ logger.info(f"Detected DiT dtype: {dit_dtype}")
869
+ return dit_dtype
870
+
871
+
872
+ def load_wan_model(
873
+ config: any,
874
+ device: Union[str, torch.device],
875
+ dit_path: str,
876
+ attn_mode: str,
877
+ split_attn: bool,
878
+ loading_device: Union[str, torch.device],
879
+ dit_weight_dtype: Optional[torch.dtype],
880
+ fp8_scaled: bool = False,
881
+ ) -> WanModel:
882
+ # dit_weight_dtype is None for fp8_scaled
883
+ assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
884
+
885
+ device = torch.device(device)
886
+ loading_device = torch.device(loading_device)
887
+
888
+ with init_empty_weights():
889
+ logger.info(f"Creating WanModel")
890
+ model = WanModel(
891
+ model_type="i2v" if config.i2v else "t2v",
892
+ dim=config.dim,
893
+ eps=config.eps,
894
+ ffn_dim=config.ffn_dim,
895
+ freq_dim=config.freq_dim,
896
+ in_dim=config.in_dim,
897
+ num_heads=config.num_heads,
898
+ num_layers=config.num_layers,
899
+ out_dim=config.out_dim,
900
+ text_len=config.text_len,
901
+ attn_mode=attn_mode,
902
+ split_attn=split_attn,
903
+ )
904
+ if dit_weight_dtype is not None:
905
+ model.to(dit_weight_dtype)
906
+
907
+ # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
908
+ wan_loading_device = torch.device("cpu") if fp8_scaled else loading_device
909
+ logger.info(f"Loading DiT model from {dit_path}, device={wan_loading_device}, dtype={dit_weight_dtype}")
910
+
911
+ # load model weights with the specified dtype or as is
912
+ sd = load_safetensors(dit_path, wan_loading_device, disable_mmap=True, dtype=dit_weight_dtype)
913
+
914
+ # remove "model.diffusion_model." prefix: 1.3B model has this prefix
915
+ for key in list(sd.keys()):
916
+ if key.startswith("model.diffusion_model."):
917
+ sd[key[22:]] = sd.pop(key)
918
+
919
+ if fp8_scaled:
920
+ # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
921
+ logger.info(f"Optimizing model weights to fp8. This may take a while.")
922
+ sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")
923
+
924
+ if loading_device.type != "cpu":
925
+ # make sure all the model weights are on the loading_device
926
+ logger.info(f"Moving weights to {loading_device}")
927
+ for key in sd.keys():
928
+ sd[key] = sd[key].to(loading_device)
929
+
930
+ info = model.load_state_dict(sd, strict=True, assign=True)
931
+ logger.info(f"Loaded DiT model from {dit_path}, info={info}")
932
+
933
+ return model