Upload 15 files
Browse files- fpack_cache_latents.py +124 -54
- fpack_generate_video.py +438 -317
- fpack_train_network.py +617 -0
- hv_train.py +1721 -0
- hv_train_network.py +0 -0
- wan_cache_latents.py +177 -0
- wan_cache_text_encoder_outputs.py +107 -0
- wan_generate_video.py +1902 -0
- wan_train_network.py +444 -0
fpack_cache_latents.py
CHANGED
@@ -2,13 +2,14 @@ import argparse
|
|
2 |
import logging
|
3 |
import math
|
4 |
import os
|
5 |
-
from typing import List
|
6 |
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
from tqdm import tqdm
|
11 |
from transformers import SiglipImageProcessor, SiglipVisionModel
|
|
|
12 |
|
13 |
from dataset import config_utils
|
14 |
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
@@ -28,15 +29,20 @@ def encode_and_save_batch(
|
|
28 |
feature_extractor: SiglipImageProcessor,
|
29 |
image_encoder: SiglipVisionModel,
|
30 |
batch: List[ItemInfo],
|
31 |
-
latent_window_size: int,
|
32 |
vanilla_sampling: bool = False,
|
33 |
one_frame: bool = False,
|
|
|
|
|
34 |
):
|
35 |
"""Encode a batch of original RGB videos and save FramePack section caches."""
|
36 |
if one_frame:
|
37 |
-
encode_and_save_batch_one_frame(
|
|
|
|
|
38 |
return
|
39 |
|
|
|
|
|
40 |
# Stack batch into tensor (B,C,F,H,W) in RGB order
|
41 |
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
|
42 |
if len(contents.shape) == 4:
|
@@ -238,34 +244,68 @@ def encode_and_save_batch_one_frame(
|
|
238 |
feature_extractor: SiglipImageProcessor,
|
239 |
image_encoder: SiglipVisionModel,
|
240 |
batch: List[ItemInfo],
|
241 |
-
latent_window_size: int,
|
242 |
vanilla_sampling: bool = False,
|
|
|
|
|
243 |
):
|
244 |
# item.content: target image (H, W, C)
|
245 |
-
# item.control_content:
|
246 |
-
|
247 |
-
# Stack batch into tensor (B,F,H,W,C) in RGB order.
|
248 |
-
contents =
|
249 |
-
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
253 |
contents = contents.to(vae.device, dtype=vae.dtype)
|
254 |
contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
|
255 |
|
256 |
-
height, width = contents.shape[
|
257 |
if height < 8 or width < 8:
|
258 |
item = batch[0] # other items should have the same size
|
259 |
raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
|
260 |
|
261 |
-
# VAE encode
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
# Vision encoding per‑item (once): use control content because it is the start image
|
268 |
-
images = [item.control_content for item in batch] # list of [H, W, C]
|
269 |
|
270 |
# encode image with image encoder
|
271 |
image_embeddings = []
|
@@ -276,56 +316,74 @@ def encode_and_save_batch_one_frame(
|
|
276 |
image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
|
277 |
image_embeddings = image_embeddings.to("cpu") # Save memory
|
278 |
|
279 |
-
#
|
280 |
-
history_latents = torch.zeros(
|
281 |
-
(1, latents.shape[1], 1 + 2 + 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype
|
282 |
-
) # C=16 for HY
|
283 |
-
|
284 |
-
# indices generation (same as inference)
|
285 |
-
indices = torch.arange(0, sum([1, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
286 |
-
(
|
287 |
-
clean_latent_indices_pre, # Index for start_latent
|
288 |
-
latent_indices, # Indices for the target latents to predict
|
289 |
-
clean_latent_indices_post, # Index for the most recent history frame
|
290 |
-
clean_latent_2x_indices, # Indices for the next 2 history frames
|
291 |
-
clean_latent_4x_indices, # Indices for the next 16 history frames
|
292 |
-
) = indices.split([1, latent_window_size, 1, 2, 16], dim=1)
|
293 |
-
|
294 |
-
# Indices for clean_latents (start + recent history)
|
295 |
-
latent_indices = latent_indices[:, -1:] # Only the last index is used for one frame training
|
296 |
-
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
297 |
-
|
298 |
-
# clean latents preparation for all items (emulating inference)
|
299 |
-
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
|
300 |
-
|
301 |
for b, item in enumerate(batch):
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
# clean latents preparation (emulating inference)
|
305 |
-
|
306 |
-
|
|
|
|
|
307 |
|
308 |
# Target latents for this section (ground truth)
|
309 |
-
target_latents = latents[b :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
311 |
# save cache (file path is inside item.latent_cache_path pattern), remove batch dim
|
312 |
save_latent_cache_framepack(
|
313 |
item_info=item,
|
314 |
-
latent=target_latents
|
315 |
-
latent_indices=
|
316 |
-
clean_latents=clean_latents
|
317 |
-
clean_latent_indices=clean_latent_indices
|
318 |
-
clean_latents_2x=clean_latents_2x
|
319 |
-
clean_latent_2x_indices=clean_latent_2x_indices
|
320 |
-
clean_latents_4x=clean_latents_4x
|
321 |
-
clean_latent_4x_indices=clean_latent_4x_indices
|
322 |
image_embeddings=image_embeddings[b],
|
323 |
)
|
324 |
|
325 |
|
326 |
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
327 |
parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
|
328 |
-
parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
|
329 |
parser.add_argument(
|
330 |
"--f1",
|
331 |
action="store_true",
|
@@ -336,6 +394,16 @@ def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.Argument
|
|
336 |
action="store_true",
|
337 |
help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
|
338 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
return parser
|
340 |
|
341 |
|
@@ -373,7 +441,9 @@ def main(args: argparse.Namespace):
|
|
373 |
|
374 |
# encoding closure
|
375 |
def encode(batch: List[ItemInfo]):
|
376 |
-
encode_and_save_batch(
|
|
|
|
|
377 |
|
378 |
# reuse core loop from cache_latents with no change
|
379 |
encode_datasets_framepack(datasets, encode, args)
|
@@ -403,7 +473,7 @@ def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, arg
|
|
403 |
all_existing = os.path.exists(item.latent_cache_path)
|
404 |
else:
|
405 |
latent_f = (item.frame_count - 1) // 4 + 1
|
406 |
-
num_sections = max(1, math.floor((latent_f - 1) /
|
407 |
all_existing = True
|
408 |
for sec in range(num_sections):
|
409 |
p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
|
|
|
2 |
import logging
|
3 |
import math
|
4 |
import os
|
5 |
+
from typing import List, Optional
|
6 |
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
from tqdm import tqdm
|
11 |
from transformers import SiglipImageProcessor, SiglipVisionModel
|
12 |
+
from PIL import Image
|
13 |
|
14 |
from dataset import config_utils
|
15 |
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
|
|
29 |
feature_extractor: SiglipImageProcessor,
|
30 |
image_encoder: SiglipVisionModel,
|
31 |
batch: List[ItemInfo],
|
|
|
32 |
vanilla_sampling: bool = False,
|
33 |
one_frame: bool = False,
|
34 |
+
one_frame_no_2x: bool = False,
|
35 |
+
one_frame_no_4x: bool = False,
|
36 |
):
|
37 |
"""Encode a batch of original RGB videos and save FramePack section caches."""
|
38 |
if one_frame:
|
39 |
+
encode_and_save_batch_one_frame(
|
40 |
+
vae, feature_extractor, image_encoder, batch, vanilla_sampling, one_frame_no_2x, one_frame_no_4x
|
41 |
+
)
|
42 |
return
|
43 |
|
44 |
+
latent_window_size = batch[0].fp_latent_window_size # all items should have the same window size
|
45 |
+
|
46 |
# Stack batch into tensor (B,C,F,H,W) in RGB order
|
47 |
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
|
48 |
if len(contents.shape) == 4:
|
|
|
244 |
feature_extractor: SiglipImageProcessor,
|
245 |
image_encoder: SiglipVisionModel,
|
246 |
batch: List[ItemInfo],
|
|
|
247 |
vanilla_sampling: bool = False,
|
248 |
+
one_frame_no_2x: bool = False,
|
249 |
+
one_frame_no_4x: bool = False,
|
250 |
):
|
251 |
# item.content: target image (H, W, C)
|
252 |
+
# item.control_content: list of images (H, W, C)
|
253 |
+
|
254 |
+
# Stack batch into tensor (B,F,H,W,C) in RGB order. The numbers of control content for each item are the same.
|
255 |
+
contents = []
|
256 |
+
content_masks: list[list[Optional[torch.Tensor]]] = []
|
257 |
+
for item in batch:
|
258 |
+
item_contents = item.control_content + [item.content]
|
259 |
+
|
260 |
+
item_masks = []
|
261 |
+
for i, c in enumerate(item_contents):
|
262 |
+
if c.shape[-1] == 4: # RGBA
|
263 |
+
item_contents[i] = c[..., :3] # remove alpha channel from content
|
264 |
+
|
265 |
+
alpha = c[..., 3] # extract alpha channel
|
266 |
+
mask_image = Image.fromarray(alpha, mode="L")
|
267 |
+
width, height = mask_image.size
|
268 |
+
mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
|
269 |
+
mask_image = np.array(mask_image) # PIL to numpy, HWC
|
270 |
+
mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
|
271 |
+
mask_image = mask_image.squeeze(-1) # HWC -> HW
|
272 |
+
mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
|
273 |
+
mask_image = mask_image.to(torch.float32)
|
274 |
+
content_mask = mask_image
|
275 |
+
else:
|
276 |
+
content_mask = None
|
277 |
+
|
278 |
+
item_masks.append(content_mask)
|
279 |
+
|
280 |
+
item_contents = [torch.from_numpy(c) for c in item_contents]
|
281 |
+
contents.append(torch.stack(item_contents, dim=0)) # list of [F, H, W, C]
|
282 |
+
content_masks.append(item_masks)
|
283 |
+
|
284 |
+
contents = torch.stack(contents, dim=0) # B, F, H, W, C. F is control frames + target frame
|
285 |
|
286 |
contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
287 |
contents = contents.to(vae.device, dtype=vae.dtype)
|
288 |
contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
|
289 |
|
290 |
+
height, width = contents.shape[-2], contents.shape[-1]
|
291 |
if height < 8 or width < 8:
|
292 |
item = batch[0] # other items should have the same size
|
293 |
raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
|
294 |
|
295 |
+
# VAE encode: we need to encode one frame at a time because VAE encoder has stride=4 for the time dimension except for the first frame.
|
296 |
+
latents = [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
|
297 |
+
latents = torch.cat(latents, dim=2) # B, C, F, H/8, W/8
|
298 |
+
|
299 |
+
# apply alphas to latents
|
300 |
+
for b, item in enumerate(batch):
|
301 |
+
for i, content_mask in enumerate(content_masks[b]):
|
302 |
+
if content_mask is not None:
|
303 |
+
# apply mask to the latents
|
304 |
+
# print(f"Applying content mask for item {item.item_key}, frame {i}")
|
305 |
+
latents[b : b + 1, :, i : i + 1] *= content_mask
|
306 |
|
307 |
# Vision encoding per‑item (once): use control content because it is the start image
|
308 |
+
images = [item.control_content[0] for item in batch] # list of [H, W, C]
|
309 |
|
310 |
# encode image with image encoder
|
311 |
image_embeddings = []
|
|
|
316 |
image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
|
317 |
image_embeddings = image_embeddings.to("cpu") # Save memory
|
318 |
|
319 |
+
# save cache for each item in the batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
for b, item in enumerate(batch):
|
321 |
+
# indices generation (same as inference): each item may have different clean_latent_indices, so we generate them per item
|
322 |
+
clean_latent_indices = item.fp_1f_clean_indices # list of indices for clean latents
|
323 |
+
if clean_latent_indices is None or len(clean_latent_indices) == 0:
|
324 |
+
logger.warning(
|
325 |
+
f"Item {item.item_key} has no clean_latent_indices defined, using default indices for one frame training."
|
326 |
+
)
|
327 |
+
clean_latent_indices = [0]
|
328 |
+
|
329 |
+
if not item.fp_1f_no_post:
|
330 |
+
clean_latent_indices = clean_latent_indices + [1 + item.fp_latent_window_size]
|
331 |
+
clean_latent_indices = torch.Tensor(clean_latent_indices).long() # N
|
332 |
+
|
333 |
+
latent_index = torch.Tensor([item.fp_1f_target_index]).long() # 1
|
334 |
+
|
335 |
+
# zero values is not needed to cache even if one_frame_no_2x or 4x is False
|
336 |
+
clean_latents_2x = None
|
337 |
+
clean_latents_4x = None
|
338 |
+
|
339 |
+
if one_frame_no_2x:
|
340 |
+
clean_latent_2x_indices = None
|
341 |
+
else:
|
342 |
+
index = 1 + item.fp_latent_window_size + 1
|
343 |
+
clean_latent_2x_indices = torch.arange(index, index + 2) # 2
|
344 |
+
|
345 |
+
if one_frame_no_4x:
|
346 |
+
clean_latent_4x_indices = None
|
347 |
+
else:
|
348 |
+
index = 1 + item.fp_latent_window_size + 1 + 2
|
349 |
+
clean_latent_4x_indices = torch.arange(index, index + 16) # 16
|
350 |
|
351 |
# clean latents preparation (emulating inference)
|
352 |
+
clean_latents = latents[b, :, :-1] # C, F, H, W
|
353 |
+
if not item.fp_1f_no_post:
|
354 |
+
# If zero post is enabled, we need to add a zero frame at the end
|
355 |
+
clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0) # C, F+1, H, W
|
356 |
|
357 |
# Target latents for this section (ground truth)
|
358 |
+
target_latents = latents[b, :, -1:] # C, 1, H, W
|
359 |
+
|
360 |
+
print(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
|
361 |
+
print(f" Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
|
362 |
+
print(f" Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
|
363 |
+
print(f" Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
|
364 |
+
print(
|
365 |
+
f" Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
|
366 |
+
f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
|
367 |
+
)
|
368 |
+
print(f" Image embeddings: {image_embeddings[b].shape}")
|
369 |
|
370 |
# save cache (file path is inside item.latent_cache_path pattern), remove batch dim
|
371 |
save_latent_cache_framepack(
|
372 |
item_info=item,
|
373 |
+
latent=target_latents, # Ground truth for this section
|
374 |
+
latent_indices=latent_index, # Indices for the ground truth section
|
375 |
+
clean_latents=clean_latents, # Start frame + history placeholder
|
376 |
+
clean_latent_indices=clean_latent_indices, # Indices for start frame + history placeholder
|
377 |
+
clean_latents_2x=clean_latents_2x, # History placeholder
|
378 |
+
clean_latent_2x_indices=clean_latent_2x_indices, # Indices for history placeholder
|
379 |
+
clean_latents_4x=clean_latents_4x, # History placeholder
|
380 |
+
clean_latent_4x_indices=clean_latent_4x_indices, # Indices for history placeholder
|
381 |
image_embeddings=image_embeddings[b],
|
382 |
)
|
383 |
|
384 |
|
385 |
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
386 |
parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
|
|
|
387 |
parser.add_argument(
|
388 |
"--f1",
|
389 |
action="store_true",
|
|
|
394 |
action="store_true",
|
395 |
help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
|
396 |
)
|
397 |
+
parser.add_argument(
|
398 |
+
"--one_frame_no_2x",
|
399 |
+
action="store_true",
|
400 |
+
help="Do not use clean_latents_2x and clean_latent_2x_indices for one frame training.",
|
401 |
+
)
|
402 |
+
parser.add_argument(
|
403 |
+
"--one_frame_no_4x",
|
404 |
+
action="store_true",
|
405 |
+
help="Do not use clean_latents_4x and clean_latent_4x_indices for one frame training.",
|
406 |
+
)
|
407 |
return parser
|
408 |
|
409 |
|
|
|
441 |
|
442 |
# encoding closure
|
443 |
def encode(batch: List[ItemInfo]):
|
444 |
+
encode_and_save_batch(
|
445 |
+
vae, feature_extractor, image_encoder, batch, args.f1, args.one_frame, args.one_frame_no_2x, args.one_frame_no_4x
|
446 |
+
)
|
447 |
|
448 |
# reuse core loop from cache_latents with no change
|
449 |
encode_datasets_framepack(datasets, encode, args)
|
|
|
473 |
all_existing = os.path.exists(item.latent_cache_path)
|
474 |
else:
|
475 |
latent_f = (item.frame_count - 1) // 4 + 1
|
476 |
+
num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size)) # min 1 section
|
477 |
all_existing = True
|
478 |
for sec in range(num_sections):
|
479 |
p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
|
fpack_generate_video.py
CHANGED
@@ -114,20 +114,17 @@ def parse_args() -> argparse.Namespace:
|
|
114 |
"--one_frame_inference",
|
115 |
type=str,
|
116 |
default=None,
|
117 |
-
help="one frame inference, default is None, comma separated values from '
|
118 |
)
|
119 |
parser.add_argument(
|
120 |
-
"--
|
121 |
-
type=str,
|
122 |
-
default=None,
|
123 |
-
help="path to image mask for one frame inference. If specified, it will be used as mask for input image.",
|
124 |
)
|
125 |
parser.add_argument(
|
126 |
-
"--
|
127 |
type=str,
|
128 |
default=None,
|
129 |
nargs="*",
|
130 |
-
help="path to
|
131 |
)
|
132 |
parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
|
133 |
parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
|
@@ -154,7 +151,7 @@ def parse_args() -> argparse.Namespace:
|
|
154 |
default=None,
|
155 |
help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
|
156 |
)
|
157 |
-
parser.add_argument("--end_image_path", type=str,
|
158 |
parser.add_argument(
|
159 |
"--latent_paddings",
|
160 |
type=str,
|
@@ -180,6 +177,16 @@ def parse_args() -> argparse.Namespace:
|
|
180 |
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
181 |
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
182 |
# parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
184 |
parser.add_argument(
|
185 |
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
@@ -248,9 +255,9 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
|
|
248 |
|
249 |
# Create dictionary of overrides
|
250 |
overrides = {"prompt": prompt}
|
251 |
-
# Initialize
|
252 |
-
overrides["
|
253 |
-
overrides["
|
254 |
|
255 |
for part in parts[1:]:
|
256 |
if not part.strip():
|
@@ -276,8 +283,8 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
|
|
276 |
# overrides["flow_shift"] = float(value)
|
277 |
elif option == "i":
|
278 |
overrides["image_path"] = value
|
279 |
-
elif option == "im":
|
280 |
-
|
281 |
# elif option == "cn":
|
282 |
# overrides["control_path"] = value
|
283 |
elif option == "n":
|
@@ -285,17 +292,19 @@ def parse_prompt_line(line: str) -> Dict[str, Any]:
|
|
285 |
elif option == "vs": # video_sections
|
286 |
overrides["video_sections"] = int(value)
|
287 |
elif option == "ei": # end_image_path
|
288 |
-
overrides["end_image_path"]
|
289 |
-
elif option == "
|
290 |
-
overrides["
|
|
|
|
|
291 |
elif option == "of": # one_frame_inference
|
292 |
overrides["one_frame_inference"] = value
|
293 |
|
294 |
-
# If no
|
295 |
-
if not overrides["
|
296 |
-
del overrides["
|
297 |
-
if not overrides["
|
298 |
-
del overrides["
|
299 |
|
300 |
return overrides
|
301 |
|
@@ -366,6 +375,13 @@ def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVid
|
|
366 |
|
367 |
# do not fp8 optimize because we will merge LoRA weights
|
368 |
model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
return model
|
370 |
|
371 |
|
@@ -558,30 +574,44 @@ def prepare_i2v_inputs(
|
|
558 |
|
559 |
# prepare image
|
560 |
def preprocess_image(image_path: str):
|
561 |
-
image = Image.open(image_path)
|
|
|
|
|
|
|
|
|
|
|
562 |
|
563 |
image_np = np.array(image) # PIL to numpy, HWC
|
564 |
|
565 |
image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
|
566 |
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
|
567 |
image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
|
568 |
-
return image_tensor, image_np
|
569 |
|
570 |
section_image_paths = parse_section_strings(args.image_path)
|
571 |
|
572 |
section_images = {}
|
573 |
for index, image_path in section_image_paths.items():
|
574 |
-
img_tensor, img_np = preprocess_image(image_path)
|
575 |
section_images[index] = (img_tensor, img_np)
|
576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
577 |
# check end images
|
578 |
-
if args.
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
|
|
|
|
583 |
else:
|
584 |
-
|
|
|
585 |
|
586 |
# configure negative prompt
|
587 |
n_prompt = args.negative_prompt if args.negative_prompt else ""
|
@@ -644,6 +674,7 @@ def prepare_i2v_inputs(
|
|
644 |
image_encoder.to(device)
|
645 |
|
646 |
# encode image with image encoder
|
|
|
647 |
section_image_encoder_last_hidden_states = {}
|
648 |
for index, (img_tensor, img_np) in section_images.items():
|
649 |
with torch.no_grad():
|
@@ -666,14 +697,14 @@ def prepare_i2v_inputs(
|
|
666 |
start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
|
667 |
section_start_latents[index] = start_latent
|
668 |
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
|
678 |
vae.to("cpu") # move VAE to CPU to save memory
|
679 |
clean_memory_on_device(device)
|
@@ -710,7 +741,7 @@ def prepare_i2v_inputs(
|
|
710 |
}
|
711 |
arg_c_img[index] = arg_c_img_i
|
712 |
|
713 |
-
return height, width, video_seconds, arg_c, arg_null, arg_c_img,
|
714 |
|
715 |
|
716 |
# def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
|
@@ -930,13 +961,15 @@ def generate(
|
|
930 |
if shared_models is not None:
|
931 |
# Use shared models and encoded data
|
932 |
vae = shared_models.get("vae")
|
933 |
-
height, width, video_seconds, context, context_null, context_img,
|
934 |
-
args, device, vae, shared_models
|
935 |
)
|
936 |
else:
|
937 |
# prepare inputs without shared models
|
938 |
vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
|
939 |
-
height, width, video_seconds, context, context_null, context_img,
|
|
|
|
|
940 |
|
941 |
if shared_models is None or "model" not in shared_models:
|
942 |
# load DiT model
|
@@ -986,294 +1019,231 @@ def generate(
|
|
986 |
for mode in args.one_frame_inference.split(","):
|
987 |
one_frame_inference.add(mode.strip())
|
988 |
|
989 |
-
|
990 |
-
|
991 |
-
|
992 |
-
|
993 |
-
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
# One can try to remove below trick and just
|
1006 |
-
# use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
|
1007 |
-
# 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
|
1008 |
-
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
|
1009 |
-
|
1010 |
-
if args.latent_paddings is not None:
|
1011 |
-
# parse user defined latent paddings
|
1012 |
-
user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
|
1013 |
-
if len(user_latent_paddings) < total_latent_sections:
|
1014 |
-
print(
|
1015 |
-
f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
|
1016 |
-
)
|
1017 |
-
print(f"Use default paddings instead for unspecified sections.")
|
1018 |
-
latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
|
1019 |
-
elif len(user_latent_paddings) > total_latent_sections:
|
1020 |
-
print(
|
1021 |
-
f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
|
1022 |
-
)
|
1023 |
-
print(f"Use only first {total_latent_sections} paddings instead.")
|
1024 |
-
latent_paddings = user_latent_paddings[:total_latent_sections]
|
1025 |
-
else:
|
1026 |
-
latent_paddings = user_latent_paddings
|
1027 |
else:
|
1028 |
-
|
1029 |
-
history_latents = torch.
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
latent_paddings = list(latent_paddings) # make sure it's a list
|
1034 |
-
for loop_index in range(total_latent_sections):
|
1035 |
-
latent_padding = latent_paddings[loop_index]
|
1036 |
|
|
|
1037 |
if not f1_mode:
|
1038 |
# Inverted Anti-drifting
|
1039 |
-
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
1046 |
-
|
1047 |
-
|
1048 |
-
|
1049 |
-
|
1050 |
-
|
1051 |
-
|
1052 |
-
|
1053 |
-
|
1054 |
-
|
1055 |
-
|
1056 |
-
|
1057 |
-
|
1058 |
-
|
1059 |
-
|
1060 |
-
|
1061 |
-
|
1062 |
-
|
1063 |
-
|
1064 |
-
|
1065 |
-
|
1066 |
-
|
1067 |
-
|
1068 |
-
if not f1_mode:
|
1069 |
-
# Inverted Anti-drifting
|
1070 |
-
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
1071 |
-
(
|
1072 |
-
clean_latent_indices_pre,
|
1073 |
-
blank_indices,
|
1074 |
-
latent_indices,
|
1075 |
-
clean_latent_indices_post,
|
1076 |
-
clean_latent_2x_indices,
|
1077 |
-
clean_latent_4x_indices,
|
1078 |
-
) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
|
1079 |
-
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
1080 |
-
|
1081 |
-
clean_latents_pre = start_latent.to(history_latents)
|
1082 |
-
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
|
1083 |
-
[1, 2, 16], dim=2
|
1084 |
-
)
|
1085 |
-
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
1086 |
-
|
1087 |
-
if end_latents is not None:
|
1088 |
-
clean_latents = torch.cat([clean_latents_pre, history_latents[:, :, : len(end_latents)]], dim=2)
|
1089 |
-
clean_latent_indices_extended = torch.zeros(1, 1 + len(end_latents), dtype=clean_latent_indices.dtype)
|
1090 |
-
clean_latent_indices_extended[:, :2] = clean_latent_indices
|
1091 |
-
clean_latent_indices = clean_latent_indices_extended
|
1092 |
-
|
1093 |
-
else:
|
1094 |
-
# F1 mode
|
1095 |
-
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
|
1096 |
-
(
|
1097 |
-
clean_latent_indices_start,
|
1098 |
-
clean_latent_4x_indices,
|
1099 |
-
clean_latent_2x_indices,
|
1100 |
-
clean_latent_1x_indices,
|
1101 |
-
latent_indices,
|
1102 |
-
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
|
1103 |
-
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
1104 |
-
|
1105 |
-
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
|
1106 |
-
[16, 2, 1], dim=2
|
1107 |
-
)
|
1108 |
-
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
1109 |
-
|
1110 |
-
# if use_teacache:
|
1111 |
-
# transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
|
1112 |
-
# else:
|
1113 |
-
# transformer.initialize_teacache(enable_teacache=False)
|
1114 |
-
|
1115 |
-
# prepare conditioning inputs
|
1116 |
-
if section_index_from_last in context:
|
1117 |
-
prompt_index = section_index_from_last
|
1118 |
-
elif section_index in context:
|
1119 |
-
prompt_index = section_index
|
1120 |
else:
|
1121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1122 |
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
|
|
|
|
|
|
1126 |
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1130 |
|
1131 |
-
|
1132 |
-
|
1133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1134 |
|
1135 |
-
|
1136 |
-
|
1137 |
-
|
1138 |
|
1139 |
-
|
1140 |
-
|
1141 |
-
|
1142 |
-
# one frame inference
|
1143 |
-
latent_indices = latent_indices[:, -1:] # only use the last frame (default)
|
1144 |
-
sample_num_frames = 1
|
1145 |
-
|
1146 |
-
def get_latent_mask(mask_path: str):
|
1147 |
-
mask_image = Image.open(mask_path).convert("L") # grayscale
|
1148 |
-
mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
|
1149 |
-
mask_image = np.array(mask_image) # PIL to numpy, HWC
|
1150 |
-
mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
|
1151 |
-
mask_image = mask_image.squeeze(-1) # HWC -> HW
|
1152 |
-
mask_image = mask_image.unsqueeze(0).unsqueeze(0) # HW -> 11HW
|
1153 |
-
mask_image = mask_image.to(clean_latents)
|
1154 |
-
return mask_image
|
1155 |
-
|
1156 |
-
if args.image_mask_path is not None:
|
1157 |
-
mask_image = get_latent_mask(args.image_mask_path)
|
1158 |
-
logger.info(f"Apply mask for clean latents (start image): {args.image_mask_path}, shape: {mask_image.shape}")
|
1159 |
-
clean_latents[:, :, 0, :, :] = clean_latents[:, :, 0, :, :] * mask_image
|
1160 |
-
if args.end_image_mask_path is not None and len(args.end_image_mask_path) > 0:
|
1161 |
-
# # apply mask for clean latents 1x (end image)
|
1162 |
-
count = min(len(args.end_image_mask_path), len(end_latents))
|
1163 |
-
for i in range(count):
|
1164 |
-
mask_image = get_latent_mask(args.end_image_mask_path[i])
|
1165 |
-
logger.info(
|
1166 |
-
f"Apply mask for clean latents 1x (end image) for {i+1}: {args.end_image_mask_path[i]}, shape: {mask_image.shape}"
|
1167 |
-
)
|
1168 |
-
clean_latents[:, :, i + 1 : i + 2, :, :] = clean_latents[:, :, i + 1 : i + 2, :, :] * mask_image
|
1169 |
-
|
1170 |
-
for one_frame_param in one_frame_inference:
|
1171 |
-
if one_frame_param.startswith("target_index="):
|
1172 |
-
target_index = int(one_frame_param.split("=")[1])
|
1173 |
-
latent_indices[:, 0] = target_index
|
1174 |
-
logger.info(f"Set index for target: {target_index}")
|
1175 |
-
elif one_frame_param.startswith("start_index="):
|
1176 |
-
start_index = int(one_frame_param.split("=")[1])
|
1177 |
-
clean_latent_indices[:, 0] = start_index
|
1178 |
-
logger.info(f"Set index for clean latent pre (start image): {start_index}")
|
1179 |
-
elif one_frame_param.startswith("history_index="):
|
1180 |
-
history_indices = one_frame_param.split("=")[1].split(";")
|
1181 |
-
i = 0
|
1182 |
-
while i < len(history_indices) and i < len(end_latents):
|
1183 |
-
history_index = int(history_indices[i])
|
1184 |
-
clean_latent_indices[:, 1 + i] = history_index
|
1185 |
-
i += 1
|
1186 |
-
while i < len(end_latents):
|
1187 |
-
clean_latent_indices[:, 1 + i] = history_index
|
1188 |
-
i += 1
|
1189 |
-
logger.info(f"Set index for clean latent post (end image): {history_indices}")
|
1190 |
-
|
1191 |
-
if "no_2x" in one_frame_inference:
|
1192 |
-
clean_latents_2x = None
|
1193 |
-
clean_latent_2x_indices = None
|
1194 |
-
logger.info(f"No clean_latents_2x")
|
1195 |
-
if "no_4x" in one_frame_inference:
|
1196 |
-
clean_latents_4x = None
|
1197 |
-
clean_latent_4x_indices = None
|
1198 |
-
logger.info(f"No clean_latents_4x")
|
1199 |
-
if "no_post" in one_frame_inference:
|
1200 |
-
clean_latents = clean_latents[:, :, :1, :, :]
|
1201 |
-
clean_latent_indices = clean_latent_indices[:, :1]
|
1202 |
-
logger.info(f"No clean_latents post")
|
1203 |
-
elif "zero_post" in one_frame_inference:
|
1204 |
-
# zero out the history latents. this seems to prevent the images from corrupting
|
1205 |
-
clean_latents[:, :, 1:, :, :] = torch.zeros_like(clean_latents[:, :, 1:, :, :])
|
1206 |
-
logger.info(f"Zero out clean_latents post")
|
1207 |
|
1208 |
-
|
1209 |
-
|
1210 |
)
|
1211 |
|
1212 |
-
|
1213 |
-
|
1214 |
-
|
1215 |
-
|
1216 |
-
|
1217 |
-
|
1218 |
-
|
1219 |
-
|
1220 |
-
|
1221 |
-
|
1222 |
-
|
1223 |
-
|
1224 |
-
|
1225 |
-
|
1226 |
-
|
1227 |
-
|
1228 |
-
|
1229 |
-
|
1230 |
-
|
1231 |
-
|
1232 |
-
|
1233 |
-
|
1234 |
-
|
1235 |
-
|
1236 |
-
|
1237 |
-
|
1238 |
-
|
1239 |
-
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
# Inverted Anti-drifting: prepend generated latents to history latents
|
1246 |
-
if is_last_section:
|
1247 |
-
generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
|
1248 |
-
total_generated_latent_frames += 1
|
1249 |
|
1250 |
-
|
1251 |
-
|
1252 |
-
|
1253 |
-
|
1254 |
-
|
1255 |
-
|
1256 |
-
|
1257 |
-
logger.info(f"Generated. Latent shape {real_history_latents.shape}")
|
1258 |
-
|
1259 |
-
# # TODO support saving intermediate video
|
1260 |
-
# clean_memory_on_device(device)
|
1261 |
-
# vae.to(device)
|
1262 |
-
# if history_pixels is None:
|
1263 |
-
# history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
|
1264 |
-
# else:
|
1265 |
-
# section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
|
1266 |
-
# overlapped_frames = latent_window_size * 4 - 3
|
1267 |
-
# current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
|
1268 |
-
# history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
|
1269 |
-
# vae.to("cpu")
|
1270 |
-
# # if not is_last_section:
|
1271 |
-
# # # save intermediate video
|
1272 |
-
# # save_video(history_pixels[0], args, total_generated_latent_frames)
|
1273 |
-
# print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
|
1274 |
|
1275 |
-
|
1276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1277 |
|
1278 |
# Only clean up shared models if they were created within this function
|
1279 |
if shared_models is None:
|
@@ -1284,8 +1254,9 @@ def generate(
|
|
1284 |
model.to("cpu")
|
1285 |
|
1286 |
# wait for 5 seconds until block swap is done
|
1287 |
-
|
1288 |
-
|
|
|
1289 |
|
1290 |
gc.collect()
|
1291 |
clean_memory_on_device(device)
|
@@ -1293,6 +1264,156 @@ def generate(
|
|
1293 |
return vae, real_history_latents
|
1294 |
|
1295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1296 |
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
|
1297 |
"""Save latent to file
|
1298 |
|
|
|
114 |
"--one_frame_inference",
|
115 |
type=str,
|
116 |
default=None,
|
117 |
+
help="one frame inference, default is None, comma separated values from 'no_2x', 'no_4x', 'no_post', 'control_indices' and 'target_index'.",
|
118 |
)
|
119 |
parser.add_argument(
|
120 |
+
"--control_image_path", type=str, default=None, nargs="*", help="path to control (reference) image for one frame inference."
|
|
|
|
|
|
|
121 |
)
|
122 |
parser.add_argument(
|
123 |
+
"--control_image_mask_path",
|
124 |
type=str,
|
125 |
default=None,
|
126 |
nargs="*",
|
127 |
+
help="path to control (reference) image mask for one frame inference.",
|
128 |
)
|
129 |
parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
|
130 |
parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
|
|
|
151 |
default=None,
|
152 |
help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
|
153 |
)
|
154 |
+
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
|
155 |
parser.add_argument(
|
156 |
"--latent_paddings",
|
157 |
type=str,
|
|
|
177 |
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
178 |
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
179 |
# parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
|
180 |
+
parser.add_argument(
|
181 |
+
"--rope_scaling_factor", type=float, default=0.5, help="RoPE scaling factor for high resolution (H/W), default is 0.5"
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--rope_scaling_timestep_threshold",
|
185 |
+
type=int,
|
186 |
+
default=None,
|
187 |
+
help="RoPE scaling timestep threshold, default is None (disable), if set, RoPE scaling will be applied only for timesteps >= threshold, around 800 is good starting point",
|
188 |
+
)
|
189 |
+
|
190 |
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
191 |
parser.add_argument(
|
192 |
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
|
|
255 |
|
256 |
# Create dictionary of overrides
|
257 |
overrides = {"prompt": prompt}
|
258 |
+
# Initialize control_image_path and control_image_mask_path as a list to accommodate multiple paths
|
259 |
+
overrides["control_image_path"] = []
|
260 |
+
overrides["control_image_mask_path"] = []
|
261 |
|
262 |
for part in parts[1:]:
|
263 |
if not part.strip():
|
|
|
283 |
# overrides["flow_shift"] = float(value)
|
284 |
elif option == "i":
|
285 |
overrides["image_path"] = value
|
286 |
+
# elif option == "im":
|
287 |
+
# overrides["image_mask_path"] = value
|
288 |
# elif option == "cn":
|
289 |
# overrides["control_path"] = value
|
290 |
elif option == "n":
|
|
|
292 |
elif option == "vs": # video_sections
|
293 |
overrides["video_sections"] = int(value)
|
294 |
elif option == "ei": # end_image_path
|
295 |
+
overrides["end_image_path"] = value
|
296 |
+
elif option == "ci": # control_image_path
|
297 |
+
overrides["control_image_path"].append(value)
|
298 |
+
elif option == "cim": # control_image_mask_path
|
299 |
+
overrides["control_image_mask_path"].append(value)
|
300 |
elif option == "of": # one_frame_inference
|
301 |
overrides["one_frame_inference"] = value
|
302 |
|
303 |
+
# If no control_image_path was provided, remove the empty list
|
304 |
+
if not overrides["control_image_path"]:
|
305 |
+
del overrides["control_image_path"]
|
306 |
+
if not overrides["control_image_mask_path"]:
|
307 |
+
del overrides["control_image_mask_path"]
|
308 |
|
309 |
return overrides
|
310 |
|
|
|
375 |
|
376 |
# do not fp8 optimize because we will merge LoRA weights
|
377 |
model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
|
378 |
+
|
379 |
+
# apply RoPE scaling factor
|
380 |
+
if args.rope_scaling_timestep_threshold is not None:
|
381 |
+
logger.info(
|
382 |
+
f"Applying RoPE scaling factor {args.rope_scaling_factor} for timesteps >= {args.rope_scaling_timestep_threshold}"
|
383 |
+
)
|
384 |
+
model.enable_rope_scaling(args.rope_scaling_timestep_threshold, args.rope_scaling_factor)
|
385 |
return model
|
386 |
|
387 |
|
|
|
574 |
|
575 |
# prepare image
|
576 |
def preprocess_image(image_path: str):
|
577 |
+
image = Image.open(image_path)
|
578 |
+
if image.mode == "RGBA":
|
579 |
+
alpha = image.split()[-1]
|
580 |
+
else:
|
581 |
+
alpha = None
|
582 |
+
image = image.convert("RGB")
|
583 |
|
584 |
image_np = np.array(image) # PIL to numpy, HWC
|
585 |
|
586 |
image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
|
587 |
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
|
588 |
image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
|
589 |
+
return image_tensor, image_np, alpha
|
590 |
|
591 |
section_image_paths = parse_section_strings(args.image_path)
|
592 |
|
593 |
section_images = {}
|
594 |
for index, image_path in section_image_paths.items():
|
595 |
+
img_tensor, img_np, _ = preprocess_image(image_path)
|
596 |
section_images[index] = (img_tensor, img_np)
|
597 |
|
598 |
+
# check end image
|
599 |
+
if args.end_image_path is not None:
|
600 |
+
end_image_tensor, _, _ = preprocess_image(args.end_image_path)
|
601 |
+
else:
|
602 |
+
end_image_tensor = None
|
603 |
+
|
604 |
# check end images
|
605 |
+
if args.control_image_path is not None and len(args.control_image_path) > 0:
|
606 |
+
control_image_tensors = []
|
607 |
+
control_mask_images = []
|
608 |
+
for ctrl_image_path in args.control_image_path:
|
609 |
+
control_image_tensor, _, control_mask = preprocess_image(ctrl_image_path)
|
610 |
+
control_image_tensors.append(control_image_tensor)
|
611 |
+
control_mask_images.append(control_mask)
|
612 |
else:
|
613 |
+
control_image_tensors = None
|
614 |
+
control_mask_images = None
|
615 |
|
616 |
# configure negative prompt
|
617 |
n_prompt = args.negative_prompt if args.negative_prompt else ""
|
|
|
674 |
image_encoder.to(device)
|
675 |
|
676 |
# encode image with image encoder
|
677 |
+
|
678 |
section_image_encoder_last_hidden_states = {}
|
679 |
for index, (img_tensor, img_np) in section_images.items():
|
680 |
with torch.no_grad():
|
|
|
697 |
start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
|
698 |
section_start_latents[index] = start_latent
|
699 |
|
700 |
+
end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
|
701 |
+
|
702 |
+
control_latents = None
|
703 |
+
if control_image_tensors is not None:
|
704 |
+
control_latents = []
|
705 |
+
for ctrl_image_tensor in control_image_tensors:
|
706 |
+
control_latent = hunyuan.vae_encode(ctrl_image_tensor, vae).cpu()
|
707 |
+
control_latents.append(control_latent)
|
708 |
|
709 |
vae.to("cpu") # move VAE to CPU to save memory
|
710 |
clean_memory_on_device(device)
|
|
|
741 |
}
|
742 |
arg_c_img[index] = arg_c_img_i
|
743 |
|
744 |
+
return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, control_latents, control_mask_images
|
745 |
|
746 |
|
747 |
# def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
|
|
|
961 |
if shared_models is not None:
|
962 |
# Use shared models and encoded data
|
963 |
vae = shared_models.get("vae")
|
964 |
+
height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
|
965 |
+
prepare_i2v_inputs(args, device, vae, shared_models)
|
966 |
)
|
967 |
else:
|
968 |
# prepare inputs without shared models
|
969 |
vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
|
970 |
+
height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
|
971 |
+
prepare_i2v_inputs(args, device, vae)
|
972 |
+
)
|
973 |
|
974 |
if shared_models is None or "model" not in shared_models:
|
975 |
# load DiT model
|
|
|
1019 |
for mode in args.one_frame_inference.split(","):
|
1020 |
one_frame_inference.add(mode.strip())
|
1021 |
|
1022 |
+
if one_frame_inference is not None:
|
1023 |
+
real_history_latents = generate_with_one_frame_inference(
|
1024 |
+
args,
|
1025 |
+
model,
|
1026 |
+
context,
|
1027 |
+
context_null,
|
1028 |
+
context_img,
|
1029 |
+
control_latents,
|
1030 |
+
control_mask_images,
|
1031 |
+
latent_window_size,
|
1032 |
+
height,
|
1033 |
+
width,
|
1034 |
+
device,
|
1035 |
+
seed_g,
|
1036 |
+
one_frame_inference,
|
1037 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1038 |
else:
|
1039 |
+
# prepare history latents
|
1040 |
+
history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
|
1041 |
+
if end_latent is not None and not f1_mode:
|
1042 |
+
logger.info(f"Use end image(s): {args.end_image_path}")
|
1043 |
+
history_latents[:, :, :1] = end_latent.to(history_latents)
|
|
|
|
|
|
|
1044 |
|
1045 |
+
# prepare clean latents and indices
|
1046 |
if not f1_mode:
|
1047 |
# Inverted Anti-drifting
|
1048 |
+
total_generated_latent_frames = 0
|
1049 |
+
latent_paddings = reversed(range(total_latent_sections))
|
1050 |
+
|
1051 |
+
if total_latent_sections > 4 and one_frame_inference is None:
|
1052 |
+
# In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
|
1053 |
+
# items looks better than expanding it when total_latent_sections > 4
|
1054 |
+
# One can try to remove below trick and just
|
1055 |
+
# use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
|
1056 |
+
# 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
|
1057 |
+
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
|
1058 |
+
|
1059 |
+
if args.latent_paddings is not None:
|
1060 |
+
# parse user defined latent paddings
|
1061 |
+
user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
|
1062 |
+
if len(user_latent_paddings) < total_latent_sections:
|
1063 |
+
print(
|
1064 |
+
f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
|
1065 |
+
)
|
1066 |
+
print(f"Use default paddings instead for unspecified sections.")
|
1067 |
+
latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
|
1068 |
+
elif len(user_latent_paddings) > total_latent_sections:
|
1069 |
+
print(
|
1070 |
+
f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
|
1071 |
+
)
|
1072 |
+
print(f"Use only first {total_latent_sections} paddings instead.")
|
1073 |
+
latent_paddings = user_latent_paddings[:total_latent_sections]
|
1074 |
+
else:
|
1075 |
+
latent_paddings = user_latent_paddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1076 |
else:
|
1077 |
+
start_latent = context_img[0]["start_latent"]
|
1078 |
+
history_latents = torch.cat([history_latents, start_latent], dim=2)
|
1079 |
+
total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
|
1080 |
+
latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
|
1081 |
+
|
1082 |
+
latent_paddings = list(latent_paddings) # make sure it's a list
|
1083 |
+
for loop_index in range(total_latent_sections):
|
1084 |
+
latent_padding = latent_paddings[loop_index]
|
1085 |
+
|
1086 |
+
if not f1_mode:
|
1087 |
+
# Inverted Anti-drifting
|
1088 |
+
section_index_reverse = loop_index # 0, 1, 2, 3
|
1089 |
+
section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
|
1090 |
+
section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
|
1091 |
+
|
1092 |
+
is_last_section = section_index == 0
|
1093 |
+
is_first_section = section_index_reverse == 0
|
1094 |
+
latent_padding_size = latent_padding * latent_window_size
|
1095 |
+
|
1096 |
+
logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
|
1097 |
+
else:
|
1098 |
+
section_index = loop_index # 0, 1, 2, 3
|
1099 |
+
section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
|
1100 |
+
is_last_section = loop_index == total_latent_sections - 1
|
1101 |
+
is_first_section = loop_index == 0
|
1102 |
+
latent_padding_size = 0 # dummy padding for F1 mode
|
1103 |
+
|
1104 |
+
# select start latent
|
1105 |
+
if section_index_from_last in context_img:
|
1106 |
+
image_index = section_index_from_last
|
1107 |
+
elif section_index in context_img:
|
1108 |
+
image_index = section_index
|
1109 |
+
else:
|
1110 |
+
image_index = 0
|
1111 |
|
1112 |
+
start_latent = context_img[image_index]["start_latent"]
|
1113 |
+
image_path = context_img[image_index]["image_path"]
|
1114 |
+
if image_index != 0: # use section image other than section 0
|
1115 |
+
logger.info(
|
1116 |
+
f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}"
|
1117 |
+
)
|
1118 |
|
1119 |
+
if not f1_mode:
|
1120 |
+
# Inverted Anti-drifting
|
1121 |
+
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
1122 |
+
(
|
1123 |
+
clean_latent_indices_pre,
|
1124 |
+
blank_indices,
|
1125 |
+
latent_indices,
|
1126 |
+
clean_latent_indices_post,
|
1127 |
+
clean_latent_2x_indices,
|
1128 |
+
clean_latent_4x_indices,
|
1129 |
+
) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
|
1130 |
+
|
1131 |
+
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
1132 |
+
|
1133 |
+
clean_latents_pre = start_latent.to(history_latents)
|
1134 |
+
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
|
1135 |
+
[1, 2, 16], dim=2
|
1136 |
+
)
|
1137 |
+
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
1138 |
|
1139 |
+
else:
|
1140 |
+
# F1 mode
|
1141 |
+
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
|
1142 |
+
(
|
1143 |
+
clean_latent_indices_start,
|
1144 |
+
clean_latent_4x_indices,
|
1145 |
+
clean_latent_2x_indices,
|
1146 |
+
clean_latent_1x_indices,
|
1147 |
+
latent_indices,
|
1148 |
+
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
|
1149 |
+
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
1150 |
+
|
1151 |
+
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
|
1152 |
+
[16, 2, 1], dim=2
|
1153 |
+
)
|
1154 |
+
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
1155 |
+
|
1156 |
+
# if use_teacache:
|
1157 |
+
# transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
|
1158 |
+
# else:
|
1159 |
+
# transformer.initialize_teacache(enable_teacache=False)
|
1160 |
+
|
1161 |
+
# prepare conditioning inputs
|
1162 |
+
if section_index_from_last in context:
|
1163 |
+
prompt_index = section_index_from_last
|
1164 |
+
elif section_index in context:
|
1165 |
+
prompt_index = section_index
|
1166 |
+
else:
|
1167 |
+
prompt_index = 0
|
1168 |
|
1169 |
+
context_for_index = context[prompt_index]
|
1170 |
+
# if args.section_prompts is not None:
|
1171 |
+
logger.info(f"Section {section_index}: {context_for_index['prompt']}")
|
1172 |
|
1173 |
+
llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
|
1174 |
+
llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
|
1175 |
+
clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1176 |
|
1177 |
+
image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
|
1178 |
+
device, dtype=torch.bfloat16
|
1179 |
)
|
1180 |
|
1181 |
+
llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
|
1182 |
+
llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
|
1183 |
+
clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
1184 |
+
|
1185 |
+
generated_latents = sample_hunyuan(
|
1186 |
+
transformer=model,
|
1187 |
+
sampler=args.sample_solver,
|
1188 |
+
width=width,
|
1189 |
+
height=height,
|
1190 |
+
frames=num_frames,
|
1191 |
+
real_guidance_scale=args.guidance_scale,
|
1192 |
+
distilled_guidance_scale=args.embedded_cfg_scale,
|
1193 |
+
guidance_rescale=args.guidance_rescale,
|
1194 |
+
# shift=3.0,
|
1195 |
+
num_inference_steps=args.infer_steps,
|
1196 |
+
generator=seed_g,
|
1197 |
+
prompt_embeds=llama_vec,
|
1198 |
+
prompt_embeds_mask=llama_attention_mask,
|
1199 |
+
prompt_poolers=clip_l_pooler,
|
1200 |
+
negative_prompt_embeds=llama_vec_n,
|
1201 |
+
negative_prompt_embeds_mask=llama_attention_mask_n,
|
1202 |
+
negative_prompt_poolers=clip_l_pooler_n,
|
1203 |
+
device=device,
|
1204 |
+
dtype=torch.bfloat16,
|
1205 |
+
image_embeddings=image_encoder_last_hidden_state,
|
1206 |
+
latent_indices=latent_indices,
|
1207 |
+
clean_latents=clean_latents,
|
1208 |
+
clean_latent_indices=clean_latent_indices,
|
1209 |
+
clean_latents_2x=clean_latents_2x,
|
1210 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
1211 |
+
clean_latents_4x=clean_latents_4x,
|
1212 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
1213 |
+
)
|
|
|
|
|
|
|
|
|
1214 |
|
1215 |
+
# concatenate generated latents
|
1216 |
+
total_generated_latent_frames += int(generated_latents.shape[2])
|
1217 |
+
if not f1_mode:
|
1218 |
+
# Inverted Anti-drifting: prepend generated latents to history latents
|
1219 |
+
if is_last_section:
|
1220 |
+
generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
|
1221 |
+
total_generated_latent_frames += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1222 |
|
1223 |
+
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
|
1224 |
+
real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
|
1225 |
+
else:
|
1226 |
+
# F1 mode: append generated latents to history latents
|
1227 |
+
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
|
1228 |
+
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
|
1229 |
+
|
1230 |
+
logger.info(f"Generated. Latent shape {real_history_latents.shape}")
|
1231 |
+
|
1232 |
+
# # TODO support saving intermediate video
|
1233 |
+
# clean_memory_on_device(device)
|
1234 |
+
# vae.to(device)
|
1235 |
+
# if history_pixels is None:
|
1236 |
+
# history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
|
1237 |
+
# else:
|
1238 |
+
# section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
|
1239 |
+
# overlapped_frames = latent_window_size * 4 - 3
|
1240 |
+
# current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
|
1241 |
+
# history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
|
1242 |
+
# vae.to("cpu")
|
1243 |
+
# # if not is_last_section:
|
1244 |
+
# # # save intermediate video
|
1245 |
+
# # save_video(history_pixels[0], args, total_generated_latent_frames)
|
1246 |
+
# print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
|
1247 |
|
1248 |
# Only clean up shared models if they were created within this function
|
1249 |
if shared_models is None:
|
|
|
1254 |
model.to("cpu")
|
1255 |
|
1256 |
# wait for 5 seconds until block swap is done
|
1257 |
+
if args.blocks_to_swap > 0:
|
1258 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
1259 |
+
time.sleep(5)
|
1260 |
|
1261 |
gc.collect()
|
1262 |
clean_memory_on_device(device)
|
|
|
1264 |
return vae, real_history_latents
|
1265 |
|
1266 |
|
1267 |
+
def generate_with_one_frame_inference(
|
1268 |
+
args: argparse.Namespace,
|
1269 |
+
model: HunyuanVideoTransformer3DModelPacked,
|
1270 |
+
context: Dict[int, Dict[str, torch.Tensor]],
|
1271 |
+
context_null: Dict[str, torch.Tensor],
|
1272 |
+
context_img: Dict[int, Dict[str, torch.Tensor]],
|
1273 |
+
control_latents: Optional[List[torch.Tensor]],
|
1274 |
+
control_mask_images: Optional[List[Optional[Image.Image]]],
|
1275 |
+
latent_window_size: int,
|
1276 |
+
height: int,
|
1277 |
+
width: int,
|
1278 |
+
device: torch.device,
|
1279 |
+
seed_g: torch.Generator,
|
1280 |
+
one_frame_inference: set[str],
|
1281 |
+
) -> torch.Tensor:
|
1282 |
+
# one frame inference
|
1283 |
+
sample_num_frames = 1
|
1284 |
+
latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
|
1285 |
+
latent_indices[:, 0] = latent_window_size # last of latent_window
|
1286 |
+
|
1287 |
+
def get_latent_mask(mask_image: Image.Image) -> torch.Tensor:
|
1288 |
+
if mask_image.mode != "L":
|
1289 |
+
mask_image = mask_image.convert("L")
|
1290 |
+
mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
|
1291 |
+
mask_image = np.array(mask_image) # PIL to numpy, HWC
|
1292 |
+
mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
|
1293 |
+
mask_image = mask_image.squeeze(-1) # HWC -> HW
|
1294 |
+
mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
|
1295 |
+
mask_image = mask_image.to(torch.float32)
|
1296 |
+
return mask_image
|
1297 |
+
|
1298 |
+
if control_latents is None or len(control_latents) == 0:
|
1299 |
+
logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
|
1300 |
+
control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
|
1301 |
+
|
1302 |
+
if "no_post" not in one_frame_inference:
|
1303 |
+
# add zero latents as clean latents post
|
1304 |
+
control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
|
1305 |
+
logger.info(f"Add zero latents as clean latents post for one frame inference.")
|
1306 |
+
|
1307 |
+
# kisekaeichi and 1f-mc: both are using control images, but indices are different
|
1308 |
+
clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
|
1309 |
+
clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
|
1310 |
+
if "no_post" not in one_frame_inference:
|
1311 |
+
clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
|
1312 |
+
|
1313 |
+
for i in range(len(control_latents)):
|
1314 |
+
mask_image = None
|
1315 |
+
if args.control_image_mask_path is not None and i < len(args.control_image_mask_path):
|
1316 |
+
mask_image = get_latent_mask(Image.open(args.control_image_mask_path[i]))
|
1317 |
+
logger.info(
|
1318 |
+
f"Apply mask for clean latents 1x for {i + 1}: {args.control_image_mask_path[i]}, shape: {mask_image.shape}"
|
1319 |
+
)
|
1320 |
+
elif control_mask_images is not None and i < len(control_mask_images) and control_mask_images[i] is not None:
|
1321 |
+
mask_image = get_latent_mask(control_mask_images[i])
|
1322 |
+
logger.info(f"Apply mask for clean latents 1x for {i + 1} with alpha channel: {mask_image.shape}")
|
1323 |
+
if mask_image is not None:
|
1324 |
+
clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * mask_image
|
1325 |
+
|
1326 |
+
for one_frame_param in one_frame_inference:
|
1327 |
+
if one_frame_param.startswith("target_index="):
|
1328 |
+
target_index = int(one_frame_param.split("=")[1])
|
1329 |
+
latent_indices[:, 0] = target_index
|
1330 |
+
logger.info(f"Set index for target: {target_index}")
|
1331 |
+
elif one_frame_param.startswith("control_index="):
|
1332 |
+
control_indices = one_frame_param.split("=")[1].split(";")
|
1333 |
+
i = 0
|
1334 |
+
while i < len(control_indices) and i < clean_latent_indices.shape[1]:
|
1335 |
+
control_index = int(control_indices[i])
|
1336 |
+
clean_latent_indices[:, i] = control_index
|
1337 |
+
i += 1
|
1338 |
+
logger.info(f"Set index for clean latent 1x: {control_indices}")
|
1339 |
+
|
1340 |
+
# "default" option does nothing, so we can skip it
|
1341 |
+
if "default" in one_frame_inference:
|
1342 |
+
pass
|
1343 |
+
|
1344 |
+
if "no_2x" in one_frame_inference:
|
1345 |
+
clean_latents_2x = None
|
1346 |
+
clean_latent_2x_indices = None
|
1347 |
+
logger.info(f"No clean_latents_2x")
|
1348 |
+
else:
|
1349 |
+
clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
|
1350 |
+
index = 1 + latent_window_size + 1
|
1351 |
+
clean_latent_2x_indices = torch.arange(index, index + 2).unsqueeze(0) # 2
|
1352 |
+
|
1353 |
+
if "no_4x" in one_frame_inference:
|
1354 |
+
clean_latents_4x = None
|
1355 |
+
clean_latent_4x_indices = None
|
1356 |
+
logger.info(f"No clean_latents_4x")
|
1357 |
+
else:
|
1358 |
+
clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
|
1359 |
+
index = 1 + latent_window_size + 1 + 2
|
1360 |
+
clean_latent_4x_indices = torch.arange(index, index + 16).unsqueeze(0) # 16
|
1361 |
+
|
1362 |
+
logger.info(
|
1363 |
+
f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
|
1364 |
+
)
|
1365 |
+
|
1366 |
+
# prepare conditioning inputs
|
1367 |
+
prompt_index = 0
|
1368 |
+
image_index = 0
|
1369 |
+
|
1370 |
+
context_for_index = context[prompt_index]
|
1371 |
+
logger.info(f"Prompt: {context_for_index['prompt']}")
|
1372 |
+
|
1373 |
+
llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
|
1374 |
+
llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
|
1375 |
+
clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
1376 |
+
|
1377 |
+
image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=torch.bfloat16)
|
1378 |
+
|
1379 |
+
llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
|
1380 |
+
llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
|
1381 |
+
clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
1382 |
+
|
1383 |
+
generated_latents = sample_hunyuan(
|
1384 |
+
transformer=model,
|
1385 |
+
sampler=args.sample_solver,
|
1386 |
+
width=width,
|
1387 |
+
height=height,
|
1388 |
+
frames=1,
|
1389 |
+
real_guidance_scale=args.guidance_scale,
|
1390 |
+
distilled_guidance_scale=args.embedded_cfg_scale,
|
1391 |
+
guidance_rescale=args.guidance_rescale,
|
1392 |
+
# shift=3.0,
|
1393 |
+
num_inference_steps=args.infer_steps,
|
1394 |
+
generator=seed_g,
|
1395 |
+
prompt_embeds=llama_vec,
|
1396 |
+
prompt_embeds_mask=llama_attention_mask,
|
1397 |
+
prompt_poolers=clip_l_pooler,
|
1398 |
+
negative_prompt_embeds=llama_vec_n,
|
1399 |
+
negative_prompt_embeds_mask=llama_attention_mask_n,
|
1400 |
+
negative_prompt_poolers=clip_l_pooler_n,
|
1401 |
+
device=device,
|
1402 |
+
dtype=torch.bfloat16,
|
1403 |
+
image_embeddings=image_encoder_last_hidden_state,
|
1404 |
+
latent_indices=latent_indices,
|
1405 |
+
clean_latents=clean_latents,
|
1406 |
+
clean_latent_indices=clean_latent_indices,
|
1407 |
+
clean_latents_2x=clean_latents_2x,
|
1408 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
1409 |
+
clean_latents_4x=clean_latents_4x,
|
1410 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
1411 |
+
)
|
1412 |
+
|
1413 |
+
real_history_latents = generated_latents.to(clean_latents)
|
1414 |
+
return real_history_latents
|
1415 |
+
|
1416 |
+
|
1417 |
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
|
1418 |
"""Save latent to file
|
1419 |
|
fpack_train_network.py
ADDED
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
from typing import Optional
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
from tqdm import tqdm
|
13 |
+
from accelerate import Accelerator, init_empty_weights
|
14 |
+
|
15 |
+
from dataset import image_video_dataset
|
16 |
+
from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ARCHITECTURE_FRAMEPACK_FULL, load_video
|
17 |
+
from fpack_generate_video import decode_latent
|
18 |
+
from frame_pack import hunyuan
|
19 |
+
from frame_pack.clip_vision import hf_clip_vision_encode
|
20 |
+
from frame_pack.framepack_utils import load_image_encoders, load_text_encoder1, load_text_encoder2
|
21 |
+
from frame_pack.framepack_utils import load_vae as load_framepack_vae
|
22 |
+
from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
|
23 |
+
from frame_pack.k_diffusion_hunyuan import sample_hunyuan
|
24 |
+
from frame_pack.utils import crop_or_pad_yield_mask
|
25 |
+
from dataset.image_video_dataset import resize_image_to_bucket
|
26 |
+
from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
|
27 |
+
|
28 |
+
import logging
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
logging.basicConfig(level=logging.INFO)
|
32 |
+
|
33 |
+
from utils import model_utils
|
34 |
+
from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
|
35 |
+
|
36 |
+
|
37 |
+
class FramePackNetworkTrainer(NetworkTrainer):
|
38 |
+
def __init__(self):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
# region model specific
|
42 |
+
|
43 |
+
@property
|
44 |
+
def architecture(self) -> str:
|
45 |
+
return ARCHITECTURE_FRAMEPACK
|
46 |
+
|
47 |
+
@property
|
48 |
+
def architecture_full_name(self) -> str:
|
49 |
+
return ARCHITECTURE_FRAMEPACK_FULL
|
50 |
+
|
51 |
+
def handle_model_specific_args(self, args):
|
52 |
+
self._i2v_training = True
|
53 |
+
self._control_training = False
|
54 |
+
self.default_guidance_scale = 10.0 # embeded guidance scale
|
55 |
+
|
56 |
+
def process_sample_prompts(
|
57 |
+
self,
|
58 |
+
args: argparse.Namespace,
|
59 |
+
accelerator: Accelerator,
|
60 |
+
sample_prompts: str,
|
61 |
+
):
|
62 |
+
device = accelerator.device
|
63 |
+
|
64 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
|
65 |
+
prompts = load_prompts(sample_prompts)
|
66 |
+
|
67 |
+
# load text encoder
|
68 |
+
tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
|
69 |
+
tokenizer2, text_encoder2 = load_text_encoder2(args)
|
70 |
+
text_encoder2.to(device)
|
71 |
+
|
72 |
+
sample_prompts_te_outputs = {} # (prompt) -> (t1 embeds, t1 mask, t2 embeds)
|
73 |
+
for prompt_dict in prompts:
|
74 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
75 |
+
if p is None or p in sample_prompts_te_outputs:
|
76 |
+
continue
|
77 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
78 |
+
with torch.amp.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
|
79 |
+
llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(p, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
|
80 |
+
llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
|
81 |
+
|
82 |
+
llama_vec = llama_vec.to("cpu")
|
83 |
+
llama_attention_mask = llama_attention_mask.to("cpu")
|
84 |
+
clip_l_pooler = clip_l_pooler.to("cpu")
|
85 |
+
sample_prompts_te_outputs[p] = (llama_vec, llama_attention_mask, clip_l_pooler)
|
86 |
+
del text_encoder1, text_encoder2
|
87 |
+
clean_memory_on_device(device)
|
88 |
+
|
89 |
+
# image embedding for I2V training
|
90 |
+
feature_extractor, image_encoder = load_image_encoders(args)
|
91 |
+
image_encoder.to(device)
|
92 |
+
|
93 |
+
# encode image with image encoder
|
94 |
+
sample_prompts_image_embs = {}
|
95 |
+
for prompt_dict in prompts:
|
96 |
+
image_path = prompt_dict.get("image_path", None)
|
97 |
+
assert image_path is not None, "image_path should be set for I2V training"
|
98 |
+
if image_path in sample_prompts_image_embs:
|
99 |
+
continue
|
100 |
+
|
101 |
+
logger.info(f"Encoding image to image encoder context: {image_path}")
|
102 |
+
|
103 |
+
height = prompt_dict.get("height", 256)
|
104 |
+
width = prompt_dict.get("width", 256)
|
105 |
+
|
106 |
+
img = Image.open(image_path).convert("RGB")
|
107 |
+
img_np = np.array(img) # PIL to numpy, HWC
|
108 |
+
img_np = image_video_dataset.resize_image_to_bucket(img_np, (width, height)) # returns a numpy array
|
109 |
+
|
110 |
+
with torch.no_grad():
|
111 |
+
image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
|
112 |
+
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
|
113 |
+
|
114 |
+
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to("cpu")
|
115 |
+
sample_prompts_image_embs[image_path] = image_encoder_last_hidden_state
|
116 |
+
|
117 |
+
del image_encoder
|
118 |
+
clean_memory_on_device(device)
|
119 |
+
|
120 |
+
# prepare sample parameters
|
121 |
+
sample_parameters = []
|
122 |
+
for prompt_dict in prompts:
|
123 |
+
prompt_dict_copy = prompt_dict.copy()
|
124 |
+
|
125 |
+
p = prompt_dict.get("prompt", "")
|
126 |
+
llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
|
127 |
+
prompt_dict_copy["llama_vec"] = llama_vec
|
128 |
+
prompt_dict_copy["llama_attention_mask"] = llama_attention_mask
|
129 |
+
prompt_dict_copy["clip_l_pooler"] = clip_l_pooler
|
130 |
+
|
131 |
+
p = prompt_dict.get("negative_prompt", "")
|
132 |
+
llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
|
133 |
+
prompt_dict_copy["negative_llama_vec"] = llama_vec
|
134 |
+
prompt_dict_copy["negative_llama_attention_mask"] = llama_attention_mask
|
135 |
+
prompt_dict_copy["negative_clip_l_pooler"] = clip_l_pooler
|
136 |
+
|
137 |
+
p = prompt_dict.get("image_path", None)
|
138 |
+
prompt_dict_copy["image_encoder_last_hidden_state"] = sample_prompts_image_embs[p]
|
139 |
+
|
140 |
+
sample_parameters.append(prompt_dict_copy)
|
141 |
+
|
142 |
+
clean_memory_on_device(accelerator.device)
|
143 |
+
return sample_parameters
|
144 |
+
|
145 |
+
def do_inference(
|
146 |
+
self,
|
147 |
+
accelerator,
|
148 |
+
args,
|
149 |
+
sample_parameter,
|
150 |
+
vae,
|
151 |
+
dit_dtype,
|
152 |
+
transformer,
|
153 |
+
discrete_flow_shift,
|
154 |
+
sample_steps,
|
155 |
+
width,
|
156 |
+
height,
|
157 |
+
frame_count,
|
158 |
+
generator,
|
159 |
+
do_classifier_free_guidance,
|
160 |
+
guidance_scale,
|
161 |
+
cfg_scale,
|
162 |
+
image_path=None,
|
163 |
+
control_video_path=None,
|
164 |
+
):
|
165 |
+
"""architecture dependent inference"""
|
166 |
+
model: HunyuanVideoTransformer3DModelPacked = transformer
|
167 |
+
device = accelerator.device
|
168 |
+
if cfg_scale is None:
|
169 |
+
cfg_scale = 1.0
|
170 |
+
do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
|
171 |
+
|
172 |
+
# prepare parameters
|
173 |
+
one_frame_mode = args.one_frame
|
174 |
+
if one_frame_mode:
|
175 |
+
one_frame_inference = set()
|
176 |
+
for mode in sample_parameter["one_frame"].split(","):
|
177 |
+
one_frame_inference.add(mode.strip())
|
178 |
+
else:
|
179 |
+
one_frame_inference = None
|
180 |
+
|
181 |
+
latent_window_size = args.latent_window_size # default is 9
|
182 |
+
latent_f = (frame_count - 1) // 4 + 1
|
183 |
+
total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
|
184 |
+
if total_latent_sections < 1 and not one_frame_mode:
|
185 |
+
logger.warning(f"Not enough frames for FramePack: {latent_f}, minimum: {latent_window_size*4+1}")
|
186 |
+
return None
|
187 |
+
|
188 |
+
latent_f = total_latent_sections * latent_window_size + 1
|
189 |
+
actual_frame_count = (latent_f - 1) * 4 + 1
|
190 |
+
if actual_frame_count != frame_count:
|
191 |
+
logger.info(f"Frame count mismatch: {actual_frame_count} != {frame_count}, trimming to {actual_frame_count}")
|
192 |
+
frame_count = actual_frame_count
|
193 |
+
num_frames = latent_window_size * 4 - 3
|
194 |
+
|
195 |
+
# prepare start and control latent
|
196 |
+
def encode_image(path):
|
197 |
+
image = Image.open(path)
|
198 |
+
if image.mode == "RGBA":
|
199 |
+
alpha = image.split()[-1]
|
200 |
+
image = image.convert("RGB")
|
201 |
+
else:
|
202 |
+
alpha = None
|
203 |
+
image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
|
204 |
+
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).unsqueeze(0).float() # 1, C, 1, H, W
|
205 |
+
image = image / 127.5 - 1 # -1 to 1
|
206 |
+
return hunyuan.vae_encode(image, vae).to("cpu"), alpha
|
207 |
+
|
208 |
+
# VAE encoding
|
209 |
+
logger.info(f"Encoding image to latent space")
|
210 |
+
vae.to(device)
|
211 |
+
|
212 |
+
start_latent, _ = (
|
213 |
+
encode_image(image_path) if image_path else torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32)
|
214 |
+
)
|
215 |
+
|
216 |
+
if one_frame_mode:
|
217 |
+
control_latents = []
|
218 |
+
control_alphas = []
|
219 |
+
if "control_image_path" in sample_parameter:
|
220 |
+
for control_image_path in sample_parameter["control_image_path"]:
|
221 |
+
control_latent, control_alpha = encode_image(control_image_path)
|
222 |
+
control_latents.append(control_latent)
|
223 |
+
control_alphas.append(control_alpha)
|
224 |
+
else:
|
225 |
+
control_latents = None
|
226 |
+
control_alphas = None
|
227 |
+
|
228 |
+
vae.to("cpu") # move VAE to CPU to save memory
|
229 |
+
clean_memory_on_device(device)
|
230 |
+
|
231 |
+
# sampilng
|
232 |
+
if not one_frame_mode:
|
233 |
+
f1_mode = args.f1
|
234 |
+
history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
|
235 |
+
|
236 |
+
if not f1_mode:
|
237 |
+
total_generated_latent_frames = 0
|
238 |
+
latent_paddings = reversed(range(total_latent_sections))
|
239 |
+
else:
|
240 |
+
total_generated_latent_frames = 1
|
241 |
+
history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
|
242 |
+
latent_paddings = [0] * total_latent_sections
|
243 |
+
|
244 |
+
if total_latent_sections > 4:
|
245 |
+
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
|
246 |
+
|
247 |
+
latent_paddings = list(latent_paddings)
|
248 |
+
for loop_index in range(total_latent_sections):
|
249 |
+
latent_padding = latent_paddings[loop_index]
|
250 |
+
|
251 |
+
if not f1_mode:
|
252 |
+
is_last_section = latent_padding == 0
|
253 |
+
latent_padding_size = latent_padding * latent_window_size
|
254 |
+
|
255 |
+
logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
|
256 |
+
|
257 |
+
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
258 |
+
(
|
259 |
+
clean_latent_indices_pre,
|
260 |
+
blank_indices,
|
261 |
+
latent_indices,
|
262 |
+
clean_latent_indices_post,
|
263 |
+
clean_latent_2x_indices,
|
264 |
+
clean_latent_4x_indices,
|
265 |
+
) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
|
266 |
+
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
267 |
+
|
268 |
+
clean_latents_pre = start_latent.to(history_latents)
|
269 |
+
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
|
270 |
+
[1, 2, 16], dim=2
|
271 |
+
)
|
272 |
+
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
273 |
+
else:
|
274 |
+
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
|
275 |
+
(
|
276 |
+
clean_latent_indices_start,
|
277 |
+
clean_latent_4x_indices,
|
278 |
+
clean_latent_2x_indices,
|
279 |
+
clean_latent_1x_indices,
|
280 |
+
latent_indices,
|
281 |
+
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
|
282 |
+
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
283 |
+
|
284 |
+
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
|
285 |
+
[16, 2, 1], dim=2
|
286 |
+
)
|
287 |
+
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
288 |
+
|
289 |
+
# if use_teacache:
|
290 |
+
# transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
|
291 |
+
# else:
|
292 |
+
# transformer.initialize_teacache(enable_teacache=False)
|
293 |
+
|
294 |
+
llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
|
295 |
+
llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
|
296 |
+
clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
297 |
+
if cfg_scale == 1.0:
|
298 |
+
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
|
299 |
+
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
|
300 |
+
else:
|
301 |
+
llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
|
302 |
+
llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
|
303 |
+
clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
304 |
+
image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
|
305 |
+
device, dtype=torch.bfloat16
|
306 |
+
)
|
307 |
+
|
308 |
+
generated_latents = sample_hunyuan(
|
309 |
+
transformer=model,
|
310 |
+
sampler=args.sample_solver,
|
311 |
+
width=width,
|
312 |
+
height=height,
|
313 |
+
frames=num_frames,
|
314 |
+
real_guidance_scale=cfg_scale,
|
315 |
+
distilled_guidance_scale=guidance_scale,
|
316 |
+
guidance_rescale=0.0,
|
317 |
+
# shift=3.0,
|
318 |
+
num_inference_steps=sample_steps,
|
319 |
+
generator=generator,
|
320 |
+
prompt_embeds=llama_vec,
|
321 |
+
prompt_embeds_mask=llama_attention_mask,
|
322 |
+
prompt_poolers=clip_l_pooler,
|
323 |
+
negative_prompt_embeds=llama_vec_n,
|
324 |
+
negative_prompt_embeds_mask=llama_attention_mask_n,
|
325 |
+
negative_prompt_poolers=clip_l_pooler_n,
|
326 |
+
device=device,
|
327 |
+
dtype=torch.bfloat16,
|
328 |
+
image_embeddings=image_encoder_last_hidden_state,
|
329 |
+
latent_indices=latent_indices,
|
330 |
+
clean_latents=clean_latents,
|
331 |
+
clean_latent_indices=clean_latent_indices,
|
332 |
+
clean_latents_2x=clean_latents_2x,
|
333 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
334 |
+
clean_latents_4x=clean_latents_4x,
|
335 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
336 |
+
)
|
337 |
+
|
338 |
+
total_generated_latent_frames += int(generated_latents.shape[2])
|
339 |
+
if not f1_mode:
|
340 |
+
if is_last_section:
|
341 |
+
generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
|
342 |
+
total_generated_latent_frames += 1
|
343 |
+
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
|
344 |
+
real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
|
345 |
+
else:
|
346 |
+
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
|
347 |
+
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
|
348 |
+
|
349 |
+
logger.info(f"Generated. Latent shape {real_history_latents.shape}")
|
350 |
+
else:
|
351 |
+
# one frame mode
|
352 |
+
sample_num_frames = 1
|
353 |
+
latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
|
354 |
+
latent_indices[:, 0] = latent_window_size # last of latent_window
|
355 |
+
|
356 |
+
def get_latent_mask(mask_image: Image.Image):
|
357 |
+
mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
|
358 |
+
mask_image = np.array(mask_image) # PIL to numpy, HWC
|
359 |
+
mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
|
360 |
+
mask_image = mask_image.squeeze(-1) # HWC -> HW
|
361 |
+
mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (B, C, F, H, W)
|
362 |
+
mask_image = mask_image.to(torch.float32)
|
363 |
+
return mask_image
|
364 |
+
|
365 |
+
if control_latents is None or len(control_latents) == 0:
|
366 |
+
logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
|
367 |
+
control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
|
368 |
+
|
369 |
+
if "no_post" not in one_frame_inference:
|
370 |
+
# add zero latents as clean latents post
|
371 |
+
control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
|
372 |
+
logger.info(f"Add zero latents as clean latents post for one frame inference.")
|
373 |
+
|
374 |
+
# kisekaeichi and 1f-mc: both are using control images, but indices are different
|
375 |
+
clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
|
376 |
+
clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
|
377 |
+
if "no_post" not in one_frame_inference:
|
378 |
+
clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
|
379 |
+
|
380 |
+
# apply mask for control latents (clean latents)
|
381 |
+
for i in range(len(control_alphas)):
|
382 |
+
control_alpha = control_alphas[i]
|
383 |
+
if control_alpha is not None:
|
384 |
+
latent_mask = get_latent_mask(control_alpha)
|
385 |
+
logger.info(
|
386 |
+
f"Apply mask for clean latents 1x for {i+1}: shape: {latent_mask.shape}"
|
387 |
+
)
|
388 |
+
clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * latent_mask
|
389 |
+
|
390 |
+
for one_frame_param in one_frame_inference:
|
391 |
+
if one_frame_param.startswith("target_index="):
|
392 |
+
target_index = int(one_frame_param.split("=")[1])
|
393 |
+
latent_indices[:, 0] = target_index
|
394 |
+
logger.info(f"Set index for target: {target_index}")
|
395 |
+
elif one_frame_param.startswith("control_index="):
|
396 |
+
control_indices = one_frame_param.split("=")[1].split(";")
|
397 |
+
i = 0
|
398 |
+
while i < len(control_indices) and i < clean_latent_indices.shape[1]:
|
399 |
+
control_index = int(control_indices[i])
|
400 |
+
clean_latent_indices[:, i] = control_index
|
401 |
+
i += 1
|
402 |
+
logger.info(f"Set index for clean latent 1x: {control_indices}")
|
403 |
+
|
404 |
+
if "no_2x" in one_frame_inference:
|
405 |
+
clean_latents_2x = None
|
406 |
+
clean_latent_2x_indices = None
|
407 |
+
logger.info(f"No clean_latents_2x")
|
408 |
+
else:
|
409 |
+
clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
|
410 |
+
index = 1 + latent_window_size + 1
|
411 |
+
clean_latent_2x_indices = torch.arange(index, index + 2) # 2
|
412 |
+
|
413 |
+
if "no_4x" in one_frame_inference:
|
414 |
+
clean_latents_4x = None
|
415 |
+
clean_latent_4x_indices = None
|
416 |
+
logger.info(f"No clean_latents_4x")
|
417 |
+
else:
|
418 |
+
index = 1 + latent_window_size + 1 + 2
|
419 |
+
clean_latent_4x_indices = torch.arange(index, index + 16) # 16
|
420 |
+
|
421 |
+
logger.info(
|
422 |
+
f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
|
423 |
+
)
|
424 |
+
|
425 |
+
# prepare conditioning inputs
|
426 |
+
llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
|
427 |
+
llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
|
428 |
+
clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
429 |
+
if cfg_scale == 1.0:
|
430 |
+
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
|
431 |
+
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
|
432 |
+
else:
|
433 |
+
llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
|
434 |
+
llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
|
435 |
+
clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
436 |
+
image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
|
437 |
+
device, dtype=torch.bfloat16
|
438 |
+
)
|
439 |
+
|
440 |
+
generated_latents = sample_hunyuan(
|
441 |
+
transformer=model,
|
442 |
+
sampler=args.sample_solver,
|
443 |
+
width=width,
|
444 |
+
height=height,
|
445 |
+
frames=1,
|
446 |
+
real_guidance_scale=cfg_scale,
|
447 |
+
distilled_guidance_scale=guidance_scale,
|
448 |
+
guidance_rescale=0.0,
|
449 |
+
# shift=3.0,
|
450 |
+
num_inference_steps=sample_steps,
|
451 |
+
generator=generator,
|
452 |
+
prompt_embeds=llama_vec,
|
453 |
+
prompt_embeds_mask=llama_attention_mask,
|
454 |
+
prompt_poolers=clip_l_pooler,
|
455 |
+
negative_prompt_embeds=llama_vec_n,
|
456 |
+
negative_prompt_embeds_mask=llama_attention_mask_n,
|
457 |
+
negative_prompt_poolers=clip_l_pooler_n,
|
458 |
+
device=device,
|
459 |
+
dtype=torch.bfloat16,
|
460 |
+
image_embeddings=image_encoder_last_hidden_state,
|
461 |
+
latent_indices=latent_indices,
|
462 |
+
clean_latents=clean_latents,
|
463 |
+
clean_latent_indices=clean_latent_indices,
|
464 |
+
clean_latents_2x=clean_latents_2x,
|
465 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
466 |
+
clean_latents_4x=clean_latents_4x,
|
467 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
468 |
+
)
|
469 |
+
|
470 |
+
real_history_latents = generated_latents.to(clean_latents)
|
471 |
+
|
472 |
+
# wait for 5 seconds until block swap is done
|
473 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
474 |
+
time.sleep(5)
|
475 |
+
|
476 |
+
gc.collect()
|
477 |
+
clean_memory_on_device(device)
|
478 |
+
|
479 |
+
video = decode_latent(
|
480 |
+
latent_window_size, total_latent_sections, args.bulk_decode, vae, real_history_latents, device, one_frame_mode
|
481 |
+
)
|
482 |
+
video = video.to("cpu", dtype=torch.float32).unsqueeze(0) # add batch dimension
|
483 |
+
video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
|
484 |
+
clean_memory_on_device(device)
|
485 |
+
|
486 |
+
return video
|
487 |
+
|
488 |
+
def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
|
489 |
+
vae_path = args.vae
|
490 |
+
logger.info(f"Loading VAE model from {vae_path}")
|
491 |
+
vae = load_framepack_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
|
492 |
+
return vae
|
493 |
+
|
494 |
+
def load_transformer(
|
495 |
+
self,
|
496 |
+
accelerator: Accelerator,
|
497 |
+
args: argparse.Namespace,
|
498 |
+
dit_path: str,
|
499 |
+
attn_mode: str,
|
500 |
+
split_attn: bool,
|
501 |
+
loading_device: str,
|
502 |
+
dit_weight_dtype: Optional[torch.dtype],
|
503 |
+
):
|
504 |
+
logger.info(f"Loading DiT model from {dit_path}")
|
505 |
+
device = accelerator.device
|
506 |
+
model = load_packed_model(device, dit_path, attn_mode, loading_device, args.fp8_scaled, split_attn)
|
507 |
+
return model
|
508 |
+
|
509 |
+
def scale_shift_latents(self, latents):
|
510 |
+
# FramePack VAE includes scaling
|
511 |
+
return latents
|
512 |
+
|
513 |
+
def call_dit(
|
514 |
+
self,
|
515 |
+
args: argparse.Namespace,
|
516 |
+
accelerator: Accelerator,
|
517 |
+
transformer,
|
518 |
+
latents: torch.Tensor,
|
519 |
+
batch: dict[str, torch.Tensor],
|
520 |
+
noise: torch.Tensor,
|
521 |
+
noisy_model_input: torch.Tensor,
|
522 |
+
timesteps: torch.Tensor,
|
523 |
+
network_dtype: torch.dtype,
|
524 |
+
):
|
525 |
+
model: HunyuanVideoTransformer3DModelPacked = transformer
|
526 |
+
device = accelerator.device
|
527 |
+
batch_size = latents.shape[0]
|
528 |
+
|
529 |
+
# maybe model.dtype is better than network_dtype...
|
530 |
+
distilled_guidance = torch.tensor([args.guidance_scale * 1000.0] * batch_size).to(device=device, dtype=network_dtype)
|
531 |
+
latents = latents.to(device=accelerator.device, dtype=network_dtype)
|
532 |
+
noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
|
533 |
+
# for k, v in batch.items():
|
534 |
+
# if isinstance(v, torch.Tensor):
|
535 |
+
# print(f"{k}: {v.shape} {v.dtype} {v.device}")
|
536 |
+
with accelerator.autocast():
|
537 |
+
clean_latent_2x_indices = batch["clean_latent_2x_indices"] if "clean_latent_2x_indices" in batch else None
|
538 |
+
if clean_latent_2x_indices is not None:
|
539 |
+
clean_latent_2x = batch["latents_clean_2x"] if "latents_clean_2x" in batch else None
|
540 |
+
if clean_latent_2x is None:
|
541 |
+
clean_latent_2x = torch.zeros(
|
542 |
+
(batch_size, 16, 2, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
|
543 |
+
)
|
544 |
+
else:
|
545 |
+
clean_latent_2x = None
|
546 |
+
|
547 |
+
clean_latent_4x_indices = batch["clean_latent_4x_indices"] if "clean_latent_4x_indices" in batch else None
|
548 |
+
if clean_latent_4x_indices is not None:
|
549 |
+
clean_latent_4x = batch["latents_clean_4x"] if "latents_clean_4x" in batch else None
|
550 |
+
if clean_latent_4x is None:
|
551 |
+
clean_latent_4x = torch.zeros(
|
552 |
+
(batch_size, 16, 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
|
553 |
+
)
|
554 |
+
else:
|
555 |
+
clean_latent_4x = None
|
556 |
+
|
557 |
+
model_pred = model(
|
558 |
+
hidden_states=noisy_model_input,
|
559 |
+
timestep=timesteps,
|
560 |
+
encoder_hidden_states=batch["llama_vec"],
|
561 |
+
encoder_attention_mask=batch["llama_attention_mask"],
|
562 |
+
pooled_projections=batch["clip_l_pooler"],
|
563 |
+
guidance=distilled_guidance,
|
564 |
+
latent_indices=batch["latent_indices"],
|
565 |
+
clean_latents=batch["latents_clean"],
|
566 |
+
clean_latent_indices=batch["clean_latent_indices"],
|
567 |
+
clean_latents_2x=clean_latent_2x,
|
568 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
569 |
+
clean_latents_4x=clean_latent_4x,
|
570 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
571 |
+
image_embeddings=batch["image_embeddings"],
|
572 |
+
return_dict=False,
|
573 |
+
)
|
574 |
+
model_pred = model_pred[0] # returns tuple (model_pred, )
|
575 |
+
|
576 |
+
# flow matching loss
|
577 |
+
target = noise - latents
|
578 |
+
|
579 |
+
return model_pred, target
|
580 |
+
|
581 |
+
# endregion model specific
|
582 |
+
|
583 |
+
|
584 |
+
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
585 |
+
"""FramePack specific parser setup"""
|
586 |
+
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
587 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
|
588 |
+
parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
|
589 |
+
parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
|
590 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
591 |
+
parser.add_argument(
|
592 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
593 |
+
)
|
594 |
+
parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
|
595 |
+
parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
|
596 |
+
parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once in sample generation")
|
597 |
+
parser.add_argument("--f1", action="store_true", help="Use F1 sampling method for sample generation")
|
598 |
+
parser.add_argument("--one_frame", action="store_true", help="Use one frame sampling method for sample generation")
|
599 |
+
return parser
|
600 |
+
|
601 |
+
|
602 |
+
if __name__ == "__main__":
|
603 |
+
parser = setup_parser_common()
|
604 |
+
parser = framepack_setup_parser(parser)
|
605 |
+
|
606 |
+
args = parser.parse_args()
|
607 |
+
args = read_config_from_file(args, parser)
|
608 |
+
|
609 |
+
assert (
|
610 |
+
args.vae_dtype is None or args.vae_dtype == "float16"
|
611 |
+
), "VAE dtype must be float16 / VAEのdtypeはfloat16でなければなりません"
|
612 |
+
args.vae_dtype = "float16" # fixed
|
613 |
+
args.dit_dtype = "bfloat16" # fixed
|
614 |
+
args.sample_solver = "unipc" # for sample generation, fixed to unipc
|
615 |
+
|
616 |
+
trainer = FramePackNetworkTrainer()
|
617 |
+
trainer.train(args)
|
hv_train.py
ADDED
@@ -0,0 +1,1721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import asyncio
|
3 |
+
from datetime import timedelta
|
4 |
+
import gc
|
5 |
+
import importlib
|
6 |
+
import argparse
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import pathlib
|
10 |
+
import re
|
11 |
+
import sys
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import json
|
15 |
+
from multiprocessing import Value
|
16 |
+
from typing import Any, Dict, List, Optional
|
17 |
+
import accelerate
|
18 |
+
import numpy as np
|
19 |
+
from packaging.version import Version
|
20 |
+
|
21 |
+
import huggingface_hub
|
22 |
+
import toml
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from tqdm import tqdm
|
26 |
+
from accelerate.utils import set_seed
|
27 |
+
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
28 |
+
from safetensors.torch import load_file, save_file
|
29 |
+
import transformers
|
30 |
+
from diffusers.optimization import (
|
31 |
+
SchedulerType as DiffusersSchedulerType,
|
32 |
+
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
|
33 |
+
)
|
34 |
+
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
35 |
+
|
36 |
+
from dataset import config_utils
|
37 |
+
from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
|
38 |
+
import hunyuan_model.text_encoder as text_encoder_module
|
39 |
+
from hunyuan_model.vae import load_vae
|
40 |
+
import hunyuan_model.vae as vae_module
|
41 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
42 |
+
import networks.lora as lora_module
|
43 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
44 |
+
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO
|
45 |
+
|
46 |
+
import logging
|
47 |
+
|
48 |
+
from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__)
|
51 |
+
logging.basicConfig(level=logging.INFO)
|
52 |
+
|
53 |
+
|
54 |
+
BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
|
55 |
+
|
56 |
+
# TODO make separate file for some functions to commonize with other scripts
|
57 |
+
|
58 |
+
|
59 |
+
def clean_memory_on_device(device: torch.device):
|
60 |
+
r"""
|
61 |
+
Clean memory on the specified device, will be called from training scripts.
|
62 |
+
"""
|
63 |
+
gc.collect()
|
64 |
+
|
65 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
66 |
+
if device.type == "cuda":
|
67 |
+
torch.cuda.empty_cache()
|
68 |
+
if device.type == "xpu":
|
69 |
+
torch.xpu.empty_cache()
|
70 |
+
if device.type == "mps":
|
71 |
+
torch.mps.empty_cache()
|
72 |
+
|
73 |
+
|
74 |
+
# for collate_fn: epoch and step is multiprocessing.Value
|
75 |
+
class collator_class:
|
76 |
+
def __init__(self, epoch, step, dataset):
|
77 |
+
self.current_epoch = epoch
|
78 |
+
self.current_step = step
|
79 |
+
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
80 |
+
|
81 |
+
def __call__(self, examples):
|
82 |
+
worker_info = torch.utils.data.get_worker_info()
|
83 |
+
# worker_info is None in the main process
|
84 |
+
if worker_info is not None:
|
85 |
+
dataset = worker_info.dataset
|
86 |
+
else:
|
87 |
+
dataset = self.dataset
|
88 |
+
|
89 |
+
# set epoch and step
|
90 |
+
dataset.set_current_epoch(self.current_epoch.value)
|
91 |
+
dataset.set_current_step(self.current_step.value)
|
92 |
+
return examples[0]
|
93 |
+
|
94 |
+
|
95 |
+
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
|
96 |
+
"""
|
97 |
+
DeepSpeed is not supported in this script currently.
|
98 |
+
"""
|
99 |
+
if args.logging_dir is None:
|
100 |
+
logging_dir = None
|
101 |
+
else:
|
102 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
103 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
104 |
+
|
105 |
+
if args.log_with is None:
|
106 |
+
if logging_dir is not None:
|
107 |
+
log_with = "tensorboard"
|
108 |
+
else:
|
109 |
+
log_with = None
|
110 |
+
else:
|
111 |
+
log_with = args.log_with
|
112 |
+
if log_with in ["tensorboard", "all"]:
|
113 |
+
if logging_dir is None:
|
114 |
+
raise ValueError(
|
115 |
+
"logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
|
116 |
+
)
|
117 |
+
if log_with in ["wandb", "all"]:
|
118 |
+
try:
|
119 |
+
import wandb
|
120 |
+
except ImportError:
|
121 |
+
raise ImportError("No wandb / wandb がインストールされていないようです")
|
122 |
+
if logging_dir is not None:
|
123 |
+
os.makedirs(logging_dir, exist_ok=True)
|
124 |
+
os.environ["WANDB_DIR"] = logging_dir
|
125 |
+
if args.wandb_api_key is not None:
|
126 |
+
wandb.login(key=args.wandb_api_key)
|
127 |
+
|
128 |
+
kwargs_handlers = [
|
129 |
+
(
|
130 |
+
InitProcessGroupKwargs(
|
131 |
+
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
132 |
+
init_method=(
|
133 |
+
"env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
|
134 |
+
),
|
135 |
+
timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
|
136 |
+
)
|
137 |
+
if torch.cuda.device_count() > 1
|
138 |
+
else None
|
139 |
+
),
|
140 |
+
(
|
141 |
+
DistributedDataParallelKwargs(
|
142 |
+
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
|
143 |
+
)
|
144 |
+
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
145 |
+
else None
|
146 |
+
),
|
147 |
+
]
|
148 |
+
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
149 |
+
|
150 |
+
accelerator = Accelerator(
|
151 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
152 |
+
mixed_precision=args.mixed_precision,
|
153 |
+
log_with=log_with,
|
154 |
+
project_dir=logging_dir,
|
155 |
+
kwargs_handlers=kwargs_handlers,
|
156 |
+
)
|
157 |
+
print("accelerator device:", accelerator.device)
|
158 |
+
return accelerator
|
159 |
+
|
160 |
+
|
161 |
+
def line_to_prompt_dict(line: str) -> dict:
|
162 |
+
# subset of gen_img_diffusers
|
163 |
+
prompt_args = line.split(" --")
|
164 |
+
prompt_dict = {}
|
165 |
+
prompt_dict["prompt"] = prompt_args[0]
|
166 |
+
|
167 |
+
for parg in prompt_args:
|
168 |
+
try:
|
169 |
+
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
170 |
+
if m:
|
171 |
+
prompt_dict["width"] = int(m.group(1))
|
172 |
+
continue
|
173 |
+
|
174 |
+
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
175 |
+
if m:
|
176 |
+
prompt_dict["height"] = int(m.group(1))
|
177 |
+
continue
|
178 |
+
|
179 |
+
m = re.match(r"f (\d+)", parg, re.IGNORECASE)
|
180 |
+
if m:
|
181 |
+
prompt_dict["frame_count"] = int(m.group(1))
|
182 |
+
continue
|
183 |
+
|
184 |
+
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
185 |
+
if m:
|
186 |
+
prompt_dict["seed"] = int(m.group(1))
|
187 |
+
continue
|
188 |
+
|
189 |
+
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
190 |
+
if m: # steps
|
191 |
+
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
|
192 |
+
continue
|
193 |
+
|
194 |
+
# m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
195 |
+
# if m: # scale
|
196 |
+
# prompt_dict["scale"] = float(m.group(1))
|
197 |
+
# continue
|
198 |
+
# m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
199 |
+
# if m: # negative prompt
|
200 |
+
# prompt_dict["negative_prompt"] = m.group(1)
|
201 |
+
# continue
|
202 |
+
|
203 |
+
except ValueError as ex:
|
204 |
+
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
205 |
+
logger.error(ex)
|
206 |
+
|
207 |
+
return prompt_dict
|
208 |
+
|
209 |
+
|
210 |
+
def load_prompts(prompt_file: str) -> list[Dict]:
|
211 |
+
# read prompts
|
212 |
+
if prompt_file.endswith(".txt"):
|
213 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
214 |
+
lines = f.readlines()
|
215 |
+
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
216 |
+
elif prompt_file.endswith(".toml"):
|
217 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
218 |
+
data = toml.load(f)
|
219 |
+
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
220 |
+
elif prompt_file.endswith(".json"):
|
221 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
222 |
+
prompts = json.load(f)
|
223 |
+
|
224 |
+
# preprocess prompts
|
225 |
+
for i in range(len(prompts)):
|
226 |
+
prompt_dict = prompts[i]
|
227 |
+
if isinstance(prompt_dict, str):
|
228 |
+
prompt_dict = line_to_prompt_dict(prompt_dict)
|
229 |
+
prompts[i] = prompt_dict
|
230 |
+
assert isinstance(prompt_dict, dict)
|
231 |
+
|
232 |
+
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
233 |
+
prompt_dict["enum"] = i
|
234 |
+
prompt_dict.pop("subset", None)
|
235 |
+
|
236 |
+
return prompts
|
237 |
+
|
238 |
+
|
239 |
+
def compute_density_for_timestep_sampling(
|
240 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
241 |
+
):
|
242 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
243 |
+
|
244 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
245 |
+
|
246 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
247 |
+
"""
|
248 |
+
if weighting_scheme == "logit_normal":
|
249 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
250 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
251 |
+
u = torch.nn.functional.sigmoid(u)
|
252 |
+
elif weighting_scheme == "mode":
|
253 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
254 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
255 |
+
else:
|
256 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
257 |
+
return u
|
258 |
+
|
259 |
+
|
260 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
261 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
262 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
263 |
+
timesteps = timesteps.to(device)
|
264 |
+
|
265 |
+
# if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
|
266 |
+
if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
|
267 |
+
# raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
|
268 |
+
# round to nearest timestep
|
269 |
+
logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
|
270 |
+
step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
|
271 |
+
else:
|
272 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
273 |
+
|
274 |
+
sigma = sigmas[step_indices].flatten()
|
275 |
+
while len(sigma.shape) < n_dim:
|
276 |
+
sigma = sigma.unsqueeze(-1)
|
277 |
+
return sigma
|
278 |
+
|
279 |
+
|
280 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
|
281 |
+
"""Computes loss weighting scheme for SD3 training.
|
282 |
+
|
283 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
284 |
+
|
285 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
286 |
+
"""
|
287 |
+
if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
|
288 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
|
289 |
+
if weighting_scheme == "sigma_sqrt":
|
290 |
+
weighting = (sigmas**-2.0).float()
|
291 |
+
else:
|
292 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
293 |
+
weighting = 2 / (math.pi * bot)
|
294 |
+
else:
|
295 |
+
weighting = None # torch.ones_like(sigmas)
|
296 |
+
return weighting
|
297 |
+
|
298 |
+
|
299 |
+
class FineTuningTrainer:
|
300 |
+
def __init__(self):
|
301 |
+
pass
|
302 |
+
|
303 |
+
def process_sample_prompts(
|
304 |
+
self,
|
305 |
+
args: argparse.Namespace,
|
306 |
+
accelerator: Accelerator,
|
307 |
+
sample_prompts: str,
|
308 |
+
text_encoder1: str,
|
309 |
+
text_encoder2: str,
|
310 |
+
fp8_llm: bool,
|
311 |
+
):
|
312 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
|
313 |
+
prompts = load_prompts(sample_prompts)
|
314 |
+
|
315 |
+
def encode_for_text_encoder(text_encoder, is_llm=True):
|
316 |
+
sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
|
317 |
+
with accelerator.autocast(), torch.no_grad():
|
318 |
+
for prompt_dict in prompts:
|
319 |
+
for p in [prompt_dict.get("prompt", "")]:
|
320 |
+
if p not in sample_prompts_te_outputs:
|
321 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
322 |
+
|
323 |
+
data_type = "video"
|
324 |
+
text_inputs = text_encoder.text2tokens(p, data_type=data_type)
|
325 |
+
|
326 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
|
327 |
+
sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
|
328 |
+
|
329 |
+
return sample_prompts_te_outputs
|
330 |
+
|
331 |
+
# Load Text Encoder 1 and encode
|
332 |
+
text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
|
333 |
+
logger.info(f"loading text encoder 1: {text_encoder1}")
|
334 |
+
text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
|
335 |
+
|
336 |
+
logger.info("encoding with Text Encoder 1")
|
337 |
+
te_outputs_1 = encode_for_text_encoder(text_encoder_1)
|
338 |
+
del text_encoder_1
|
339 |
+
|
340 |
+
# Load Text Encoder 2 and encode
|
341 |
+
logger.info(f"loading text encoder 2: {text_encoder2}")
|
342 |
+
text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
|
343 |
+
|
344 |
+
logger.info("encoding with Text Encoder 2")
|
345 |
+
te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
|
346 |
+
del text_encoder_2
|
347 |
+
|
348 |
+
# prepare sample parameters
|
349 |
+
sample_parameters = []
|
350 |
+
for prompt_dict in prompts:
|
351 |
+
prompt_dict_copy = prompt_dict.copy()
|
352 |
+
p = prompt_dict.get("prompt", "")
|
353 |
+
prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
|
354 |
+
prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
|
355 |
+
prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
|
356 |
+
prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
|
357 |
+
sample_parameters.append(prompt_dict_copy)
|
358 |
+
|
359 |
+
clean_memory_on_device(accelerator.device)
|
360 |
+
|
361 |
+
return sample_parameters
|
362 |
+
|
363 |
+
def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
|
364 |
+
# adamw, adamw8bit, adafactor
|
365 |
+
|
366 |
+
optimizer_type = args.optimizer_type.lower()
|
367 |
+
|
368 |
+
# split optimizer_type and optimizer_args
|
369 |
+
optimizer_kwargs = {}
|
370 |
+
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
371 |
+
for arg in args.optimizer_args:
|
372 |
+
key, value = arg.split("=")
|
373 |
+
value = ast.literal_eval(value)
|
374 |
+
optimizer_kwargs[key] = value
|
375 |
+
|
376 |
+
lr = args.learning_rate
|
377 |
+
optimizer = None
|
378 |
+
optimizer_class = None
|
379 |
+
|
380 |
+
if optimizer_type.endswith("8bit".lower()):
|
381 |
+
try:
|
382 |
+
import bitsandbytes as bnb
|
383 |
+
except ImportError:
|
384 |
+
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
385 |
+
|
386 |
+
if optimizer_type == "AdamW8bit".lower():
|
387 |
+
logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
388 |
+
optimizer_class = bnb.optim.AdamW8bit
|
389 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
390 |
+
|
391 |
+
elif optimizer_type == "Adafactor".lower():
|
392 |
+
# Adafactor: check relative_step and warmup_init
|
393 |
+
if "relative_step" not in optimizer_kwargs:
|
394 |
+
optimizer_kwargs["relative_step"] = True # default
|
395 |
+
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
396 |
+
logger.info(
|
397 |
+
f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
|
398 |
+
)
|
399 |
+
optimizer_kwargs["relative_step"] = True
|
400 |
+
logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
|
401 |
+
|
402 |
+
if optimizer_kwargs["relative_step"]:
|
403 |
+
logger.info(f"relative_step is true / relative_stepがtrueです")
|
404 |
+
if lr != 0.0:
|
405 |
+
logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
406 |
+
args.learning_rate = None
|
407 |
+
|
408 |
+
if args.lr_scheduler != "adafactor":
|
409 |
+
logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
410 |
+
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
411 |
+
|
412 |
+
lr = None
|
413 |
+
else:
|
414 |
+
if args.max_grad_norm != 0.0:
|
415 |
+
logger.warning(
|
416 |
+
f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
|
417 |
+
)
|
418 |
+
if args.lr_scheduler != "constant_with_warmup":
|
419 |
+
logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
420 |
+
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
421 |
+
logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
422 |
+
|
423 |
+
optimizer_class = transformers.optimization.Adafactor
|
424 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
425 |
+
|
426 |
+
elif optimizer_type == "AdamW".lower():
|
427 |
+
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
|
428 |
+
optimizer_class = torch.optim.AdamW
|
429 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
430 |
+
|
431 |
+
if optimizer is None:
|
432 |
+
# 任意のoptimizerを使う
|
433 |
+
case_sensitive_optimizer_type = args.optimizer_type # not lower
|
434 |
+
logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
|
435 |
+
|
436 |
+
if "." not in case_sensitive_optimizer_type: # from torch.optim
|
437 |
+
optimizer_module = torch.optim
|
438 |
+
else: # from other library
|
439 |
+
values = case_sensitive_optimizer_type.split(".")
|
440 |
+
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
441 |
+
case_sensitive_optimizer_type = values[-1]
|
442 |
+
|
443 |
+
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
|
444 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
445 |
+
|
446 |
+
# for logging
|
447 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
448 |
+
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
449 |
+
|
450 |
+
# get train and eval functions
|
451 |
+
if hasattr(optimizer, "train") and callable(optimizer.train):
|
452 |
+
train_fn = optimizer.train
|
453 |
+
eval_fn = optimizer.eval
|
454 |
+
else:
|
455 |
+
train_fn = lambda: None
|
456 |
+
eval_fn = lambda: None
|
457 |
+
|
458 |
+
return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
|
459 |
+
|
460 |
+
def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
|
461 |
+
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
462 |
+
|
463 |
+
def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any:
|
464 |
+
# dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
|
465 |
+
# this scheduler is used for logging only.
|
466 |
+
# this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
|
467 |
+
class DummyScheduler:
|
468 |
+
def __init__(self, optimizer: torch.optim.Optimizer):
|
469 |
+
self.optimizer = optimizer
|
470 |
+
|
471 |
+
def step(self):
|
472 |
+
pass
|
473 |
+
|
474 |
+
def get_last_lr(self):
|
475 |
+
return [group["lr"] for group in self.optimizer.param_groups]
|
476 |
+
|
477 |
+
return DummyScheduler(optimizer)
|
478 |
+
|
479 |
+
def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
|
480 |
+
"""
|
481 |
+
Unified API to get any scheduler from its name.
|
482 |
+
"""
|
483 |
+
# if schedulefree optimizer, return dummy scheduler
|
484 |
+
if self.is_schedulefree_optimizer(optimizer, args):
|
485 |
+
return self.get_dummy_scheduler(optimizer)
|
486 |
+
|
487 |
+
name = args.lr_scheduler
|
488 |
+
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
489 |
+
num_warmup_steps: Optional[int] = (
|
490 |
+
int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
|
491 |
+
)
|
492 |
+
num_decay_steps: Optional[int] = (
|
493 |
+
int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
|
494 |
+
)
|
495 |
+
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
|
496 |
+
num_cycles = args.lr_scheduler_num_cycles
|
497 |
+
power = args.lr_scheduler_power
|
498 |
+
timescale = args.lr_scheduler_timescale
|
499 |
+
min_lr_ratio = args.lr_scheduler_min_lr_ratio
|
500 |
+
|
501 |
+
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
|
502 |
+
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
503 |
+
for arg in args.lr_scheduler_args:
|
504 |
+
key, value = arg.split("=")
|
505 |
+
value = ast.literal_eval(value)
|
506 |
+
lr_scheduler_kwargs[key] = value
|
507 |
+
|
508 |
+
def wrap_check_needless_num_warmup_steps(return_vals):
|
509 |
+
if num_warmup_steps is not None and num_warmup_steps != 0:
|
510 |
+
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
|
511 |
+
return return_vals
|
512 |
+
|
513 |
+
# using any lr_scheduler from other library
|
514 |
+
if args.lr_scheduler_type:
|
515 |
+
lr_scheduler_type = args.lr_scheduler_type
|
516 |
+
logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
|
517 |
+
if "." not in lr_scheduler_type: # default to use torch.optim
|
518 |
+
lr_scheduler_module = torch.optim.lr_scheduler
|
519 |
+
else:
|
520 |
+
values = lr_scheduler_type.split(".")
|
521 |
+
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
|
522 |
+
lr_scheduler_type = values[-1]
|
523 |
+
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
524 |
+
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
525 |
+
return lr_scheduler
|
526 |
+
|
527 |
+
if name.startswith("adafactor"):
|
528 |
+
assert (
|
529 |
+
type(optimizer) == transformers.optimization.Adafactor
|
530 |
+
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
531 |
+
initial_lr = float(name.split(":")[1])
|
532 |
+
# logger.info(f"adafactor scheduler init lr {initial_lr}")
|
533 |
+
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
|
534 |
+
|
535 |
+
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
|
536 |
+
name = DiffusersSchedulerType(name)
|
537 |
+
schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
|
538 |
+
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
539 |
+
|
540 |
+
name = SchedulerType(name)
|
541 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
542 |
+
|
543 |
+
if name == SchedulerType.CONSTANT:
|
544 |
+
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
545 |
+
|
546 |
+
# All other schedulers require `num_warmup_steps`
|
547 |
+
if num_warmup_steps is None:
|
548 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
549 |
+
|
550 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
551 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
|
552 |
+
|
553 |
+
if name == SchedulerType.INVERSE_SQRT:
|
554 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
|
555 |
+
|
556 |
+
# All other schedulers require `num_training_steps`
|
557 |
+
if num_training_steps is None:
|
558 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
559 |
+
|
560 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
561 |
+
return schedule_func(
|
562 |
+
optimizer,
|
563 |
+
num_warmup_steps=num_warmup_steps,
|
564 |
+
num_training_steps=num_training_steps,
|
565 |
+
num_cycles=num_cycles,
|
566 |
+
**lr_scheduler_kwargs,
|
567 |
+
)
|
568 |
+
|
569 |
+
if name == SchedulerType.POLYNOMIAL:
|
570 |
+
return schedule_func(
|
571 |
+
optimizer,
|
572 |
+
num_warmup_steps=num_warmup_steps,
|
573 |
+
num_training_steps=num_training_steps,
|
574 |
+
power=power,
|
575 |
+
**lr_scheduler_kwargs,
|
576 |
+
)
|
577 |
+
|
578 |
+
if name == SchedulerType.COSINE_WITH_MIN_LR:
|
579 |
+
return schedule_func(
|
580 |
+
optimizer,
|
581 |
+
num_warmup_steps=num_warmup_steps,
|
582 |
+
num_training_steps=num_training_steps,
|
583 |
+
num_cycles=num_cycles / 2,
|
584 |
+
min_lr_rate=min_lr_ratio,
|
585 |
+
**lr_scheduler_kwargs,
|
586 |
+
)
|
587 |
+
|
588 |
+
# these schedulers do not require `num_decay_steps`
|
589 |
+
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
|
590 |
+
return schedule_func(
|
591 |
+
optimizer,
|
592 |
+
num_warmup_steps=num_warmup_steps,
|
593 |
+
num_training_steps=num_training_steps,
|
594 |
+
**lr_scheduler_kwargs,
|
595 |
+
)
|
596 |
+
|
597 |
+
# All other schedulers require `num_decay_steps`
|
598 |
+
if num_decay_steps is None:
|
599 |
+
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
|
600 |
+
if name == SchedulerType.WARMUP_STABLE_DECAY:
|
601 |
+
return schedule_func(
|
602 |
+
optimizer,
|
603 |
+
num_warmup_steps=num_warmup_steps,
|
604 |
+
num_stable_steps=num_stable_steps,
|
605 |
+
num_decay_steps=num_decay_steps,
|
606 |
+
num_cycles=num_cycles / 2,
|
607 |
+
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
|
608 |
+
**lr_scheduler_kwargs,
|
609 |
+
)
|
610 |
+
|
611 |
+
return schedule_func(
|
612 |
+
optimizer,
|
613 |
+
num_warmup_steps=num_warmup_steps,
|
614 |
+
num_training_steps=num_training_steps,
|
615 |
+
num_decay_steps=num_decay_steps,
|
616 |
+
**lr_scheduler_kwargs,
|
617 |
+
)
|
618 |
+
|
619 |
+
def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
|
620 |
+
if not args.resume:
|
621 |
+
return False
|
622 |
+
|
623 |
+
if not args.resume_from_huggingface:
|
624 |
+
logger.info(f"resume training from local state: {args.resume}")
|
625 |
+
accelerator.load_state(args.resume)
|
626 |
+
return True
|
627 |
+
|
628 |
+
logger.info(f"resume training from huggingface state: {args.resume}")
|
629 |
+
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
630 |
+
path_in_repo = "/".join(args.resume.split("/")[2:])
|
631 |
+
revision = None
|
632 |
+
repo_type = None
|
633 |
+
if ":" in path_in_repo:
|
634 |
+
divided = path_in_repo.split(":")
|
635 |
+
if len(divided) == 2:
|
636 |
+
path_in_repo, revision = divided
|
637 |
+
repo_type = "model"
|
638 |
+
else:
|
639 |
+
path_in_repo, revision, repo_type = divided
|
640 |
+
logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
641 |
+
|
642 |
+
list_files = huggingface_utils.list_dir(
|
643 |
+
repo_id=repo_id,
|
644 |
+
subfolder=path_in_repo,
|
645 |
+
revision=revision,
|
646 |
+
token=args.huggingface_token,
|
647 |
+
repo_type=repo_type,
|
648 |
+
)
|
649 |
+
|
650 |
+
async def download(filename) -> str:
|
651 |
+
def task():
|
652 |
+
return huggingface_hub.hf_hub_download(
|
653 |
+
repo_id=repo_id,
|
654 |
+
filename=filename,
|
655 |
+
revision=revision,
|
656 |
+
repo_type=repo_type,
|
657 |
+
token=args.huggingface_token,
|
658 |
+
)
|
659 |
+
|
660 |
+
return await asyncio.get_event_loop().run_in_executor(None, task)
|
661 |
+
|
662 |
+
loop = asyncio.get_event_loop()
|
663 |
+
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
664 |
+
if len(results) == 0:
|
665 |
+
raise ValueError(
|
666 |
+
"No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
|
667 |
+
)
|
668 |
+
dirname = os.path.dirname(results[0])
|
669 |
+
accelerator.load_state(dirname)
|
670 |
+
|
671 |
+
return True
|
672 |
+
|
673 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
|
674 |
+
pass
|
675 |
+
|
676 |
+
def get_noisy_model_input_and_timesteps(
|
677 |
+
self,
|
678 |
+
args: argparse.Namespace,
|
679 |
+
noise: torch.Tensor,
|
680 |
+
latents: torch.Tensor,
|
681 |
+
noise_scheduler: FlowMatchDiscreteScheduler,
|
682 |
+
device: torch.device,
|
683 |
+
dtype: torch.dtype,
|
684 |
+
):
|
685 |
+
batch_size = noise.shape[0]
|
686 |
+
|
687 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
|
688 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
689 |
+
# Simple random t-based noise sampling
|
690 |
+
if args.timestep_sampling == "sigmoid":
|
691 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
|
692 |
+
else:
|
693 |
+
t = torch.rand((batch_size,), device=device)
|
694 |
+
|
695 |
+
elif args.timestep_sampling == "shift":
|
696 |
+
shift = args.discrete_flow_shift
|
697 |
+
logits_norm = torch.randn(batch_size, device=device)
|
698 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
699 |
+
t = logits_norm.sigmoid()
|
700 |
+
t = (t * shift) / (1 + (shift - 1) * t)
|
701 |
+
|
702 |
+
t_min = args.min_timestep if args.min_timestep is not None else 0
|
703 |
+
t_max = args.max_timestep if args.max_timestep is not None else 1000.0
|
704 |
+
t_min /= 1000.0
|
705 |
+
t_max /= 1000.0
|
706 |
+
t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
|
707 |
+
|
708 |
+
timesteps = t * 1000.0
|
709 |
+
t = t.view(-1, 1, 1, 1, 1)
|
710 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
711 |
+
|
712 |
+
timesteps += 1 # 1 to 1000
|
713 |
+
else:
|
714 |
+
# Sample a random timestep for each image
|
715 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
716 |
+
u = compute_density_for_timestep_sampling(
|
717 |
+
weighting_scheme=args.weighting_scheme,
|
718 |
+
batch_size=batch_size,
|
719 |
+
logit_mean=args.logit_mean,
|
720 |
+
logit_std=args.logit_std,
|
721 |
+
mode_scale=args.mode_scale,
|
722 |
+
)
|
723 |
+
# indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
724 |
+
t_min = args.min_timestep if args.min_timestep is not None else 0
|
725 |
+
t_max = args.max_timestep if args.max_timestep is not None else 1000
|
726 |
+
indices = (u * (t_max - t_min) + t_min).long()
|
727 |
+
|
728 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
|
729 |
+
|
730 |
+
# Add noise according to flow matching.
|
731 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
732 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
733 |
+
|
734 |
+
return noisy_model_input, timesteps
|
735 |
+
|
736 |
+
def train(self, args):
|
737 |
+
if args.seed is None:
|
738 |
+
args.seed = random.randint(0, 2**32)
|
739 |
+
set_seed(args.seed)
|
740 |
+
|
741 |
+
# Load dataset config
|
742 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
743 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
744 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
745 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
746 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
|
747 |
+
|
748 |
+
current_epoch = Value("i", 0)
|
749 |
+
current_step = Value("i", 0)
|
750 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
751 |
+
collator = collator_class(current_epoch, current_step, ds_for_collator)
|
752 |
+
|
753 |
+
# prepare accelerator
|
754 |
+
logger.info("preparing accelerator")
|
755 |
+
accelerator = prepare_accelerator(args)
|
756 |
+
is_main_process = accelerator.is_main_process
|
757 |
+
|
758 |
+
# prepare dtype
|
759 |
+
weight_dtype = torch.float32
|
760 |
+
if args.mixed_precision == "fp16":
|
761 |
+
weight_dtype = torch.float16
|
762 |
+
elif args.mixed_precision == "bf16":
|
763 |
+
weight_dtype = torch.bfloat16
|
764 |
+
|
765 |
+
# HunyuanVideo specific
|
766 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
|
767 |
+
|
768 |
+
# get embedding for sampling images
|
769 |
+
sample_parameters = vae = None
|
770 |
+
if args.sample_prompts:
|
771 |
+
sample_parameters = self.process_sample_prompts(
|
772 |
+
args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
|
773 |
+
)
|
774 |
+
|
775 |
+
# Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
|
776 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
|
777 |
+
vae.requires_grad_(False)
|
778 |
+
vae.eval()
|
779 |
+
|
780 |
+
if args.vae_chunk_size is not None:
|
781 |
+
vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
|
782 |
+
logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
|
783 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
784 |
+
vae.enable_spatial_tiling(True)
|
785 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
786 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
787 |
+
elif args.vae_tiling:
|
788 |
+
vae.enable_spatial_tiling(True)
|
789 |
+
|
790 |
+
# load DiT model
|
791 |
+
blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
|
792 |
+
loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
|
793 |
+
|
794 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
795 |
+
if args.sdpa:
|
796 |
+
attn_mode = "torch"
|
797 |
+
elif args.flash_attn:
|
798 |
+
attn_mode = "flash"
|
799 |
+
elif args.sage_attn:
|
800 |
+
attn_mode = "sageattn"
|
801 |
+
elif args.xformers:
|
802 |
+
attn_mode = "xformers"
|
803 |
+
else:
|
804 |
+
raise ValueError(
|
805 |
+
f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください"
|
806 |
+
)
|
807 |
+
transformer = load_transformer(
|
808 |
+
args.dit, attn_mode, args.split_attn, loading_device, None, in_channels=args.dit_in_channels
|
809 |
+
) # load as is
|
810 |
+
|
811 |
+
if blocks_to_swap > 0:
|
812 |
+
logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
|
813 |
+
transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
|
814 |
+
transformer.move_to_device_except_swap_blocks(accelerator.device)
|
815 |
+
if args.img_in_txt_in_offloading:
|
816 |
+
logger.info("Enable offloading img_in and txt_in to CPU")
|
817 |
+
transformer.enable_img_in_txt_in_offloading()
|
818 |
+
|
819 |
+
if args.gradient_checkpointing:
|
820 |
+
transformer.enable_gradient_checkpointing()
|
821 |
+
|
822 |
+
# prepare optimizer, data loader etc.
|
823 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
824 |
+
|
825 |
+
transformer.requires_grad_(False)
|
826 |
+
if accelerator.is_main_process:
|
827 |
+
accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
|
828 |
+
for name, param in transformer.named_parameters():
|
829 |
+
for trainable_module_name in args.trainable_modules:
|
830 |
+
if trainable_module_name in name:
|
831 |
+
param.requires_grad = True
|
832 |
+
break
|
833 |
+
|
834 |
+
total_params = list(transformer.parameters())
|
835 |
+
trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
836 |
+
logger.info(
|
837 |
+
f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total paramters: {sum(p.numel() for p in total_params) / 1e6} M"
|
838 |
+
)
|
839 |
+
optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
|
840 |
+
args, trainable_params
|
841 |
+
)
|
842 |
+
|
843 |
+
# prepare dataloader
|
844 |
+
|
845 |
+
# num workers for data loader: if 0, persistent_workers is not available
|
846 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
847 |
+
|
848 |
+
train_dataloader = torch.utils.data.DataLoader(
|
849 |
+
train_dataset_group,
|
850 |
+
batch_size=1,
|
851 |
+
shuffle=True,
|
852 |
+
collate_fn=collator,
|
853 |
+
num_workers=n_workers,
|
854 |
+
persistent_workers=args.persistent_data_loader_workers,
|
855 |
+
)
|
856 |
+
|
857 |
+
# calculate max_train_steps
|
858 |
+
if args.max_train_epochs is not None:
|
859 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
860 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
861 |
+
)
|
862 |
+
accelerator.print(
|
863 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
864 |
+
)
|
865 |
+
|
866 |
+
# send max_train_steps to train_dataset_group
|
867 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
868 |
+
|
869 |
+
# prepare lr_scheduler
|
870 |
+
lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
|
871 |
+
|
872 |
+
# prepare training model. accelerator does some magic here
|
873 |
+
|
874 |
+
# experimental feature: train the model with gradients in fp16/bf16
|
875 |
+
dit_dtype = torch.float32
|
876 |
+
if args.full_fp16:
|
877 |
+
assert (
|
878 |
+
args.mixed_precision == "fp16"
|
879 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
880 |
+
accelerator.print("enable full fp16 training.")
|
881 |
+
dit_weight_dtype = torch.float16
|
882 |
+
elif args.full_bf16:
|
883 |
+
assert (
|
884 |
+
args.mixed_precision == "bf16"
|
885 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
886 |
+
accelerator.print("enable full bf16 training.")
|
887 |
+
dit_weight_dtype = torch.bfloat16
|
888 |
+
else:
|
889 |
+
dit_weight_dtype = torch.float32
|
890 |
+
|
891 |
+
# TODO add fused optimizer and stochastic rounding
|
892 |
+
|
893 |
+
# cast model to dit_weight_dtype
|
894 |
+
# if dit_dtype != dit_weight_dtype:
|
895 |
+
logger.info(f"casting model to {dit_weight_dtype}")
|
896 |
+
transformer.to(dit_weight_dtype)
|
897 |
+
|
898 |
+
if blocks_to_swap > 0:
|
899 |
+
transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
|
900 |
+
accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
901 |
+
accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
|
902 |
+
else:
|
903 |
+
transformer = accelerator.prepare(transformer)
|
904 |
+
|
905 |
+
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
906 |
+
|
907 |
+
transformer.train()
|
908 |
+
|
909 |
+
if args.full_fp16:
|
910 |
+
# patch accelerator for fp16 training
|
911 |
+
# def patch_accelerator_for_fp16_training(accelerator):
|
912 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
913 |
+
|
914 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
915 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
916 |
+
|
917 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
918 |
+
|
919 |
+
# resume from local or huggingface. accelerator.step is set
|
920 |
+
self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
|
921 |
+
|
922 |
+
# epoch数を計算する
|
923 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
924 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
925 |
+
|
926 |
+
# 学習���る
|
927 |
+
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
928 |
+
|
929 |
+
accelerator.print("running training / 学習開始")
|
930 |
+
accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
|
931 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
932 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
933 |
+
accelerator.print(
|
934 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
935 |
+
)
|
936 |
+
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
937 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
938 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
939 |
+
|
940 |
+
if accelerator.is_main_process:
|
941 |
+
init_kwargs = {}
|
942 |
+
if args.wandb_run_name:
|
943 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
944 |
+
if args.log_tracker_config is not None:
|
945 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
946 |
+
accelerator.init_trackers(
|
947 |
+
"hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
|
948 |
+
config=train_utils.get_sanitized_config_or_none(args),
|
949 |
+
init_kwargs=init_kwargs,
|
950 |
+
)
|
951 |
+
|
952 |
+
# TODO skip until initial step
|
953 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
954 |
+
|
955 |
+
epoch_to_start = 0
|
956 |
+
global_step = 0
|
957 |
+
noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
|
958 |
+
|
959 |
+
loss_recorder = train_utils.LossRecorder()
|
960 |
+
del train_dataset_group
|
961 |
+
|
962 |
+
# function for saving/removing
|
963 |
+
def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
964 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
965 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
966 |
+
|
967 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
968 |
+
|
969 |
+
title = args.metadata_title if args.metadata_title is not None else args.output_name
|
970 |
+
if args.min_timestep is not None or args.max_timestep is not None:
|
971 |
+
min_time_step = args.min_timestep if args.min_timestep is not None else 0
|
972 |
+
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
|
973 |
+
md_timesteps = (min_time_step, max_time_step)
|
974 |
+
else:
|
975 |
+
md_timesteps = None
|
976 |
+
|
977 |
+
sai_metadata = sai_model_spec.build_metadata(
|
978 |
+
None,
|
979 |
+
ARCHITECTURE_HUNYUAN_VIDEO,
|
980 |
+
time.time(),
|
981 |
+
title,
|
982 |
+
None,
|
983 |
+
args.metadata_author,
|
984 |
+
args.metadata_description,
|
985 |
+
args.metadata_license,
|
986 |
+
args.metadata_tags,
|
987 |
+
timesteps=md_timesteps,
|
988 |
+
is_lora=False,
|
989 |
+
)
|
990 |
+
|
991 |
+
save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
|
992 |
+
if args.huggingface_repo_id is not None:
|
993 |
+
huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
994 |
+
|
995 |
+
def remove_model(old_ckpt_name):
|
996 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
997 |
+
if os.path.exists(old_ckpt_file):
|
998 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
999 |
+
os.remove(old_ckpt_file)
|
1000 |
+
|
1001 |
+
# For --sample_at_first
|
1002 |
+
optimizer_eval_fn()
|
1003 |
+
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
|
1004 |
+
optimizer_train_fn()
|
1005 |
+
if len(accelerator.trackers) > 0:
|
1006 |
+
# log empty object to commit the sample images to wandb
|
1007 |
+
accelerator.log({}, step=0)
|
1008 |
+
|
1009 |
+
# training loop
|
1010 |
+
|
1011 |
+
# log device and dtype for each model
|
1012 |
+
logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
|
1013 |
+
|
1014 |
+
clean_memory_on_device(accelerator.device)
|
1015 |
+
|
1016 |
+
pos_embed_cache = {}
|
1017 |
+
|
1018 |
+
for epoch in range(epoch_to_start, num_train_epochs):
|
1019 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
1020 |
+
current_epoch.value = epoch + 1
|
1021 |
+
|
1022 |
+
for step, batch in enumerate(train_dataloader):
|
1023 |
+
latents, llm_embeds, llm_mask, clip_embeds = batch
|
1024 |
+
bsz = latents.shape[0]
|
1025 |
+
current_step.value = global_step
|
1026 |
+
|
1027 |
+
with accelerator.accumulate(transformer):
|
1028 |
+
latents = latents * vae_module.SCALING_FACTOR
|
1029 |
+
|
1030 |
+
# Sample noise that we'll add to the latents
|
1031 |
+
noise = torch.randn_like(latents)
|
1032 |
+
|
1033 |
+
# calculate model input and timesteps
|
1034 |
+
noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
|
1035 |
+
args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
weighting = compute_loss_weighting_for_sd3(
|
1039 |
+
args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
# ensure guidance_scale in args is float
|
1043 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
|
1044 |
+
|
1045 |
+
# ensure the hidden state will require grad
|
1046 |
+
if args.gradient_checkpointing:
|
1047 |
+
noisy_model_input.requires_grad_(True)
|
1048 |
+
guidance_vec.requires_grad_(True)
|
1049 |
+
|
1050 |
+
pos_emb_shape = latents.shape[1:]
|
1051 |
+
if pos_emb_shape not in pos_embed_cache:
|
1052 |
+
freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(
|
1053 |
+
accelerator.unwrap_model(transformer), latents.shape[2:]
|
1054 |
+
)
|
1055 |
+
# freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
|
1056 |
+
# freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
|
1057 |
+
pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
|
1058 |
+
else:
|
1059 |
+
freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
|
1060 |
+
|
1061 |
+
# call DiT
|
1062 |
+
latents = latents.to(device=accelerator.device, dtype=dit_dtype)
|
1063 |
+
noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
|
1064 |
+
# timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
|
1065 |
+
# llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
|
1066 |
+
# llm_mask = llm_mask.to(device=accelerator.device)
|
1067 |
+
# clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
|
1068 |
+
with accelerator.autocast():
|
1069 |
+
model_pred = transformer(
|
1070 |
+
noisy_model_input,
|
1071 |
+
timesteps,
|
1072 |
+
text_states=llm_embeds,
|
1073 |
+
text_mask=llm_mask,
|
1074 |
+
text_states_2=clip_embeds,
|
1075 |
+
freqs_cos=freqs_cos,
|
1076 |
+
freqs_sin=freqs_sin,
|
1077 |
+
guidance=guidance_vec,
|
1078 |
+
return_dict=False,
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
# flow matching loss
|
1082 |
+
target = noise - latents
|
1083 |
+
|
1084 |
+
loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")
|
1085 |
+
|
1086 |
+
if weighting is not None:
|
1087 |
+
loss = loss * weighting
|
1088 |
+
# loss = loss.mean([1, 2, 3])
|
1089 |
+
# # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
1090 |
+
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
1091 |
+
|
1092 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
1093 |
+
|
1094 |
+
accelerator.backward(loss)
|
1095 |
+
if accelerator.sync_gradients:
|
1096 |
+
# self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
1097 |
+
state = accelerate.PartialState()
|
1098 |
+
if state.distributed_type != accelerate.DistributedType.NO:
|
1099 |
+
for param in transformer.parameters():
|
1100 |
+
if param.grad is not None:
|
1101 |
+
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
1102 |
+
|
1103 |
+
if args.max_grad_norm != 0.0:
|
1104 |
+
params_to_clip = accelerator.unwrap_model(transformer).parameters()
|
1105 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
1106 |
+
|
1107 |
+
optimizer.step()
|
1108 |
+
lr_scheduler.step()
|
1109 |
+
optimizer.zero_grad(set_to_none=True)
|
1110 |
+
|
1111 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1112 |
+
if accelerator.sync_gradients:
|
1113 |
+
progress_bar.update(1)
|
1114 |
+
global_step += 1
|
1115 |
+
|
1116 |
+
optimizer_eval_fn()
|
1117 |
+
self.sample_images(
|
1118 |
+
accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
|
1119 |
+
)
|
1120 |
+
|
1121 |
+
# 指定ステップごとにモデルを保存
|
1122 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
1123 |
+
accelerator.wait_for_everyone()
|
1124 |
+
if accelerator.is_main_process:
|
1125 |
+
ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
|
1126 |
+
save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)
|
1127 |
+
|
1128 |
+
if args.save_state:
|
1129 |
+
train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
|
1130 |
+
|
1131 |
+
remove_step_no = train_utils.get_remove_step_no(args, global_step)
|
1132 |
+
if remove_step_no is not None:
|
1133 |
+
remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
|
1134 |
+
remove_model(remove_ckpt_name)
|
1135 |
+
optimizer_train_fn()
|
1136 |
+
|
1137 |
+
current_loss = loss.detach().item()
|
1138 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
1139 |
+
avr_loss: float = loss_recorder.moving_average
|
1140 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
1141 |
+
progress_bar.set_postfix(**logs)
|
1142 |
+
|
1143 |
+
if len(accelerator.trackers) > 0:
|
1144 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
1145 |
+
accelerator.log(logs, step=global_step)
|
1146 |
+
|
1147 |
+
if global_step >= args.max_train_steps:
|
1148 |
+
break
|
1149 |
+
|
1150 |
+
if len(accelerator.trackers) > 0:
|
1151 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
1152 |
+
accelerator.log(logs, step=epoch + 1)
|
1153 |
+
|
1154 |
+
accelerator.wait_for_everyone()
|
1155 |
+
|
1156 |
+
# 指定エポックごとにモデルを保存
|
1157 |
+
optimizer_eval_fn()
|
1158 |
+
if args.save_every_n_epochs is not None:
|
1159 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
1160 |
+
if is_main_process and saving:
|
1161 |
+
ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
|
1162 |
+
save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)
|
1163 |
+
|
1164 |
+
remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
|
1165 |
+
if remove_epoch_no is not None:
|
1166 |
+
remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
|
1167 |
+
remove_model(remove_ckpt_name)
|
1168 |
+
|
1169 |
+
if args.save_state:
|
1170 |
+
train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
1171 |
+
|
1172 |
+
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
|
1173 |
+
optimizer_train_fn()
|
1174 |
+
|
1175 |
+
# end of epoch
|
1176 |
+
|
1177 |
+
if is_main_process:
|
1178 |
+
transformer = accelerator.unwrap_model(transformer)
|
1179 |
+
|
1180 |
+
accelerator.end_training()
|
1181 |
+
optimizer_eval_fn()
|
1182 |
+
|
1183 |
+
if args.save_state or args.save_state_on_train_end:
|
1184 |
+
train_utils.save_state_on_train_end(args, accelerator)
|
1185 |
+
|
1186 |
+
if is_main_process:
|
1187 |
+
ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
|
1188 |
+
save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)
|
1189 |
+
|
1190 |
+
logger.info("model saved.")
|
1191 |
+
|
1192 |
+
|
1193 |
+
def setup_parser() -> argparse.ArgumentParser:
|
1194 |
+
def int_or_float(value):
|
1195 |
+
if value.endswith("%"):
|
1196 |
+
try:
|
1197 |
+
return float(value[:-1]) / 100.0
|
1198 |
+
except ValueError:
|
1199 |
+
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
|
1200 |
+
try:
|
1201 |
+
float_value = float(value)
|
1202 |
+
if float_value >= 1 and float_value.is_integer():
|
1203 |
+
return int(value)
|
1204 |
+
return float(value)
|
1205 |
+
except ValueError:
|
1206 |
+
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
|
1207 |
+
|
1208 |
+
parser = argparse.ArgumentParser()
|
1209 |
+
|
1210 |
+
# general settings
|
1211 |
+
parser.add_argument(
|
1212 |
+
"--config_file",
|
1213 |
+
type=str,
|
1214 |
+
default=None,
|
1215 |
+
help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
|
1216 |
+
)
|
1217 |
+
parser.add_argument(
|
1218 |
+
"--dataset_config",
|
1219 |
+
type=pathlib.Path,
|
1220 |
+
default=None,
|
1221 |
+
required=True,
|
1222 |
+
help="config file for dataset / データセットの設定ファイル",
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
# training settings
|
1226 |
+
parser.add_argument(
|
1227 |
+
"--sdpa",
|
1228 |
+
action="store_true",
|
1229 |
+
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
|
1230 |
+
)
|
1231 |
+
parser.add_argument(
|
1232 |
+
"--flash_attn",
|
1233 |
+
action="store_true",
|
1234 |
+
help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
|
1235 |
+
)
|
1236 |
+
parser.add_argument(
|
1237 |
+
"--sage_attn",
|
1238 |
+
action="store_true",
|
1239 |
+
help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
|
1240 |
+
)
|
1241 |
+
parser.add_argument(
|
1242 |
+
"--xformers",
|
1243 |
+
action="store_true",
|
1244 |
+
help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
|
1245 |
+
)
|
1246 |
+
parser.add_argument(
|
1247 |
+
"--split_attn",
|
1248 |
+
action="store_true",
|
1249 |
+
help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
|
1250 |
+
" / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1254 |
+
parser.add_argument(
|
1255 |
+
"--max_train_epochs",
|
1256 |
+
type=int,
|
1257 |
+
default=None,
|
1258 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
|
1259 |
+
)
|
1260 |
+
parser.add_argument(
|
1261 |
+
"--max_data_loader_n_workers",
|
1262 |
+
type=int,
|
1263 |
+
default=8,
|
1264 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
|
1265 |
+
)
|
1266 |
+
parser.add_argument(
|
1267 |
+
"--persistent_data_loader_workers",
|
1268 |
+
action="store_true",
|
1269 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
|
1270 |
+
)
|
1271 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1272 |
+
parser.add_argument(
|
1273 |
+
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
1274 |
+
)
|
1275 |
+
parser.add_argument(
|
1276 |
+
"--gradient_accumulation_steps",
|
1277 |
+
type=int,
|
1278 |
+
default=1,
|
1279 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
|
1280 |
+
)
|
1281 |
+
parser.add_argument(
|
1282 |
+
"--mixed_precision",
|
1283 |
+
type=str,
|
1284 |
+
default="no",
|
1285 |
+
choices=["no", "fp16", "bf16"],
|
1286 |
+
help="use mixed precision / 混合精度を使う場合、その精度",
|
1287 |
+
)
|
1288 |
+
parser.add_argument("--trainable_modules", nargs="+", default=".", help="Enter a list of trainable modules")
|
1289 |
+
|
1290 |
+
parser.add_argument(
|
1291 |
+
"--logging_dir",
|
1292 |
+
type=str,
|
1293 |
+
default=None,
|
1294 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
1295 |
+
)
|
1296 |
+
parser.add_argument(
|
1297 |
+
"--log_with",
|
1298 |
+
type=str,
|
1299 |
+
default=None,
|
1300 |
+
choices=["tensorboard", "wandb", "all"],
|
1301 |
+
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
1302 |
+
)
|
1303 |
+
parser.add_argument(
|
1304 |
+
"--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
|
1305 |
+
)
|
1306 |
+
parser.add_argument(
|
1307 |
+
"--log_tracker_name",
|
1308 |
+
type=str,
|
1309 |
+
default=None,
|
1310 |
+
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
|
1311 |
+
)
|
1312 |
+
parser.add_argument(
|
1313 |
+
"--wandb_run_name",
|
1314 |
+
type=str,
|
1315 |
+
default=None,
|
1316 |
+
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
|
1317 |
+
)
|
1318 |
+
parser.add_argument(
|
1319 |
+
"--log_tracker_config",
|
1320 |
+
type=str,
|
1321 |
+
default=None,
|
1322 |
+
help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
|
1323 |
+
)
|
1324 |
+
parser.add_argument(
|
1325 |
+
"--wandb_api_key",
|
1326 |
+
type=str,
|
1327 |
+
default=None,
|
1328 |
+
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
1329 |
+
)
|
1330 |
+
parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
|
1331 |
+
|
1332 |
+
parser.add_argument(
|
1333 |
+
"--ddp_timeout",
|
1334 |
+
type=int,
|
1335 |
+
default=None,
|
1336 |
+
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
1337 |
+
)
|
1338 |
+
parser.add_argument(
|
1339 |
+
"--ddp_gradient_as_bucket_view",
|
1340 |
+
action="store_true",
|
1341 |
+
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
|
1342 |
+
)
|
1343 |
+
parser.add_argument(
|
1344 |
+
"--ddp_static_graph",
|
1345 |
+
action="store_true",
|
1346 |
+
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
parser.add_argument(
|
1350 |
+
"--sample_every_n_steps",
|
1351 |
+
type=int,
|
1352 |
+
default=None,
|
1353 |
+
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
|
1354 |
+
)
|
1355 |
+
parser.add_argument(
|
1356 |
+
"--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
|
1357 |
+
)
|
1358 |
+
parser.add_argument(
|
1359 |
+
"--sample_every_n_epochs",
|
1360 |
+
type=int,
|
1361 |
+
default=None,
|
1362 |
+
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
|
1363 |
+
)
|
1364 |
+
parser.add_argument(
|
1365 |
+
"--sample_prompts",
|
1366 |
+
type=str,
|
1367 |
+
default=None,
|
1368 |
+
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
# optimizer and lr scheduler settings
|
1372 |
+
parser.add_argument(
|
1373 |
+
"--optimizer_type",
|
1374 |
+
type=str,
|
1375 |
+
default="",
|
1376 |
+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
|
1377 |
+
"Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
|
1378 |
+
)
|
1379 |
+
parser.add_argument(
|
1380 |
+
"--optimizer_args",
|
1381 |
+
type=str,
|
1382 |
+
default=None,
|
1383 |
+
nargs="*",
|
1384 |
+
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
1385 |
+
)
|
1386 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1387 |
+
parser.add_argument(
|
1388 |
+
"--max_grad_norm",
|
1389 |
+
default=1.0,
|
1390 |
+
type=float,
|
1391 |
+
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
|
1392 |
+
)
|
1393 |
+
|
1394 |
+
parser.add_argument(
|
1395 |
+
"--lr_scheduler",
|
1396 |
+
type=str,
|
1397 |
+
default="constant",
|
1398 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
|
1399 |
+
)
|
1400 |
+
parser.add_argument(
|
1401 |
+
"--lr_warmup_steps",
|
1402 |
+
type=int_or_float,
|
1403 |
+
default=0,
|
1404 |
+
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
|
1405 |
+
" / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
1406 |
+
)
|
1407 |
+
parser.add_argument(
|
1408 |
+
"--lr_decay_steps",
|
1409 |
+
type=int_or_float,
|
1410 |
+
default=0,
|
1411 |
+
help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
|
1412 |
+
" / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
1413 |
+
)
|
1414 |
+
parser.add_argument(
|
1415 |
+
"--lr_scheduler_num_cycles",
|
1416 |
+
type=int,
|
1417 |
+
default=1,
|
1418 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
|
1419 |
+
)
|
1420 |
+
parser.add_argument(
|
1421 |
+
"--lr_scheduler_power",
|
1422 |
+
type=float,
|
1423 |
+
default=1,
|
1424 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
|
1425 |
+
)
|
1426 |
+
parser.add_argument(
|
1427 |
+
"--lr_scheduler_timescale",
|
1428 |
+
type=int,
|
1429 |
+
default=None,
|
1430 |
+
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
|
1431 |
+
+ " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
|
1432 |
+
)
|
1433 |
+
parser.add_argument(
|
1434 |
+
"--lr_scheduler_min_lr_ratio",
|
1435 |
+
type=float,
|
1436 |
+
default=None,
|
1437 |
+
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
|
1438 |
+
+ " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
|
1439 |
+
)
|
1440 |
+
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
|
1441 |
+
parser.add_argument(
|
1442 |
+
"--lr_scheduler_args",
|
1443 |
+
type=str,
|
1444 |
+
default=None,
|
1445 |
+
nargs="*",
|
1446 |
+
help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
|
1447 |
+
)
|
1448 |
+
|
1449 |
+
# model settings
|
1450 |
+
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
|
1451 |
+
parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
|
1452 |
+
parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
|
1453 |
+
parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
|
1454 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
1455 |
+
parser.add_argument(
|
1456 |
+
"--vae_tiling",
|
1457 |
+
action="store_true",
|
1458 |
+
help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
|
1459 |
+
" / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
|
1460 |
+
)
|
1461 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
1462 |
+
parser.add_argument(
|
1463 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
1464 |
+
)
|
1465 |
+
parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
|
1466 |
+
parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
|
1467 |
+
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
|
1468 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
|
1469 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1470 |
+
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
|
1471 |
+
|
1472 |
+
parser.add_argument(
|
1473 |
+
"--blocks_to_swap",
|
1474 |
+
type=int,
|
1475 |
+
default=None,
|
1476 |
+
help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
|
1477 |
+
)
|
1478 |
+
parser.add_argument(
|
1479 |
+
"--img_in_txt_in_offloading",
|
1480 |
+
action="store_true",
|
1481 |
+
help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
|
1482 |
+
)
|
1483 |
+
|
1484 |
+
# parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
|
1485 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.")
|
1486 |
+
parser.add_argument(
|
1487 |
+
"--timestep_sampling",
|
1488 |
+
choices=["sigma", "uniform", "sigmoid", "shift"],
|
1489 |
+
default="sigma",
|
1490 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
|
1491 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
|
1492 |
+
)
|
1493 |
+
parser.add_argument(
|
1494 |
+
"--discrete_flow_shift",
|
1495 |
+
type=float,
|
1496 |
+
default=1.0,
|
1497 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
|
1498 |
+
)
|
1499 |
+
parser.add_argument(
|
1500 |
+
"--sigmoid_scale",
|
1501 |
+
type=float,
|
1502 |
+
default=1.0,
|
1503 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
|
1504 |
+
)
|
1505 |
+
parser.add_argument(
|
1506 |
+
"--weighting_scheme",
|
1507 |
+
type=str,
|
1508 |
+
default="none",
|
1509 |
+
choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
|
1510 |
+
help="weighting scheme for timestep distribution. Default is none"
|
1511 |
+
" / タイムステップ分布の重み付けスキーム、デフォルトはnone",
|
1512 |
+
)
|
1513 |
+
parser.add_argument(
|
1514 |
+
"--logit_mean",
|
1515 |
+
type=float,
|
1516 |
+
default=0.0,
|
1517 |
+
help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
|
1518 |
+
)
|
1519 |
+
parser.add_argument(
|
1520 |
+
"--logit_std",
|
1521 |
+
type=float,
|
1522 |
+
default=1.0,
|
1523 |
+
help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
|
1524 |
+
)
|
1525 |
+
parser.add_argument(
|
1526 |
+
"--mode_scale",
|
1527 |
+
type=float,
|
1528 |
+
default=1.29,
|
1529 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
|
1530 |
+
)
|
1531 |
+
parser.add_argument(
|
1532 |
+
"--min_timestep",
|
1533 |
+
type=int,
|
1534 |
+
default=None,
|
1535 |
+
help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
|
1536 |
+
)
|
1537 |
+
parser.add_argument(
|
1538 |
+
"--max_timestep",
|
1539 |
+
type=int,
|
1540 |
+
default=None,
|
1541 |
+
help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
|
1542 |
+
)
|
1543 |
+
|
1544 |
+
# save and load settings
|
1545 |
+
parser.add_argument(
|
1546 |
+
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
|
1547 |
+
)
|
1548 |
+
parser.add_argument(
|
1549 |
+
"--output_name",
|
1550 |
+
type=str,
|
1551 |
+
default=None,
|
1552 |
+
required=True,
|
1553 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
|
1554 |
+
)
|
1555 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1556 |
+
|
1557 |
+
parser.add_argument(
|
1558 |
+
"--save_every_n_epochs",
|
1559 |
+
type=int,
|
1560 |
+
default=None,
|
1561 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
|
1562 |
+
)
|
1563 |
+
parser.add_argument(
|
1564 |
+
"--save_every_n_steps",
|
1565 |
+
type=int,
|
1566 |
+
default=None,
|
1567 |
+
help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
|
1568 |
+
)
|
1569 |
+
parser.add_argument(
|
1570 |
+
"--save_last_n_epochs",
|
1571 |
+
type=int,
|
1572 |
+
default=None,
|
1573 |
+
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
|
1574 |
+
)
|
1575 |
+
parser.add_argument(
|
1576 |
+
"--save_last_n_epochs_state",
|
1577 |
+
type=int,
|
1578 |
+
default=None,
|
1579 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
|
1580 |
+
)
|
1581 |
+
parser.add_argument(
|
1582 |
+
"--save_last_n_steps",
|
1583 |
+
type=int,
|
1584 |
+
default=None,
|
1585 |
+
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
|
1586 |
+
)
|
1587 |
+
parser.add_argument(
|
1588 |
+
"--save_last_n_steps_state",
|
1589 |
+
type=int,
|
1590 |
+
default=None,
|
1591 |
+
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
|
1592 |
+
)
|
1593 |
+
parser.add_argument(
|
1594 |
+
"--save_state",
|
1595 |
+
action="store_true",
|
1596 |
+
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
|
1597 |
+
)
|
1598 |
+
parser.add_argument(
|
1599 |
+
"--save_state_on_train_end",
|
1600 |
+
action="store_true",
|
1601 |
+
help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
|
1602 |
+
" / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
|
1603 |
+
)
|
1604 |
+
|
1605 |
+
# SAI Model spec
|
1606 |
+
parser.add_argument(
|
1607 |
+
"--metadata_title",
|
1608 |
+
type=str,
|
1609 |
+
default=None,
|
1610 |
+
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
|
1611 |
+
)
|
1612 |
+
parser.add_argument(
|
1613 |
+
"--metadata_author",
|
1614 |
+
type=str,
|
1615 |
+
default=None,
|
1616 |
+
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
|
1617 |
+
)
|
1618 |
+
parser.add_argument(
|
1619 |
+
"--metadata_description",
|
1620 |
+
type=str,
|
1621 |
+
default=None,
|
1622 |
+
help="description for model metadata / メタデータに書き込まれるモデル説明",
|
1623 |
+
)
|
1624 |
+
parser.add_argument(
|
1625 |
+
"--metadata_license",
|
1626 |
+
type=str,
|
1627 |
+
default=None,
|
1628 |
+
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
|
1629 |
+
)
|
1630 |
+
parser.add_argument(
|
1631 |
+
"--metadata_tags",
|
1632 |
+
type=str,
|
1633 |
+
default=None,
|
1634 |
+
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
1635 |
+
)
|
1636 |
+
|
1637 |
+
# huggingface settings
|
1638 |
+
parser.add_argument(
|
1639 |
+
"--huggingface_repo_id",
|
1640 |
+
type=str,
|
1641 |
+
default=None,
|
1642 |
+
help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
|
1643 |
+
)
|
1644 |
+
parser.add_argument(
|
1645 |
+
"--huggingface_repo_type",
|
1646 |
+
type=str,
|
1647 |
+
default=None,
|
1648 |
+
help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
|
1649 |
+
)
|
1650 |
+
parser.add_argument(
|
1651 |
+
"--huggingface_path_in_repo",
|
1652 |
+
type=str,
|
1653 |
+
default=None,
|
1654 |
+
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
1655 |
+
)
|
1656 |
+
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
1657 |
+
parser.add_argument(
|
1658 |
+
"--huggingface_repo_visibility",
|
1659 |
+
type=str,
|
1660 |
+
default=None,
|
1661 |
+
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
1662 |
+
)
|
1663 |
+
parser.add_argument(
|
1664 |
+
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
1665 |
+
)
|
1666 |
+
parser.add_argument(
|
1667 |
+
"--resume_from_huggingface",
|
1668 |
+
action="store_true",
|
1669 |
+
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
1670 |
+
)
|
1671 |
+
parser.add_argument(
|
1672 |
+
"--async_upload",
|
1673 |
+
action="store_true",
|
1674 |
+
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
|
1675 |
+
)
|
1676 |
+
|
1677 |
+
return parser
|
1678 |
+
|
1679 |
+
|
1680 |
+
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
1681 |
+
if not args.config_file:
|
1682 |
+
return args
|
1683 |
+
|
1684 |
+
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
1685 |
+
|
1686 |
+
if not os.path.exists(config_path):
|
1687 |
+
logger.info(f"{config_path} not found.")
|
1688 |
+
exit(1)
|
1689 |
+
|
1690 |
+
logger.info(f"Loading settings from {config_path}...")
|
1691 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
1692 |
+
config_dict = toml.load(f)
|
1693 |
+
|
1694 |
+
# combine all sections into one
|
1695 |
+
ignore_nesting_dict = {}
|
1696 |
+
for section_name, section_dict in config_dict.items():
|
1697 |
+
# if value is not dict, save key and value as is
|
1698 |
+
if not isinstance(section_dict, dict):
|
1699 |
+
ignore_nesting_dict[section_name] = section_dict
|
1700 |
+
continue
|
1701 |
+
|
1702 |
+
# if value is dict, save all key and value into one dict
|
1703 |
+
for key, value in section_dict.items():
|
1704 |
+
ignore_nesting_dict[key] = value
|
1705 |
+
|
1706 |
+
config_args = argparse.Namespace(**ignore_nesting_dict)
|
1707 |
+
args = parser.parse_args(namespace=config_args)
|
1708 |
+
args.config_file = os.path.splitext(args.config_file)[0]
|
1709 |
+
logger.info(args.config_file)
|
1710 |
+
|
1711 |
+
return args
|
1712 |
+
|
1713 |
+
|
1714 |
+
if __name__ == "__main__":
|
1715 |
+
parser = setup_parser()
|
1716 |
+
|
1717 |
+
args = parser.parse_args()
|
1718 |
+
args = read_config_from_file(args, parser)
|
1719 |
+
|
1720 |
+
trainer = FineTuningTrainer()
|
1721 |
+
trainer.train(args)
|
hv_train_network.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
wan_cache_latents.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from dataset import config_utils
|
11 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
import logging
|
15 |
+
|
16 |
+
from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN
|
17 |
+
from utils.model_utils import str_to_dtype
|
18 |
+
from wan.configs import wan_i2v_14B
|
19 |
+
from wan.modules.vae import WanVAE
|
20 |
+
from wan.modules.clip import CLIPModel
|
21 |
+
import cache_latents
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
|
27 |
+
def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
|
28 |
+
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
|
29 |
+
if len(contents.shape) == 4:
|
30 |
+
contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
|
31 |
+
|
32 |
+
contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
33 |
+
contents = contents.to(vae.device, dtype=vae.dtype)
|
34 |
+
contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
|
35 |
+
|
36 |
+
h, w = contents.shape[3], contents.shape[4]
|
37 |
+
if h < 8 or w < 8:
|
38 |
+
item = batch[0] # other items should have the same size
|
39 |
+
raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
|
40 |
+
|
41 |
+
# print(f"encode batch: {contents.shape}")
|
42 |
+
with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
|
43 |
+
latent = vae.encode(contents) # list of Tensor[C, F, H, W]
|
44 |
+
latent = torch.stack(latent, dim=0) # B, C, F, H, W
|
45 |
+
latent = latent.to(vae.dtype) # convert to bfloat16, we are not sure if this is correct
|
46 |
+
|
47 |
+
if clip is not None:
|
48 |
+
# extract first frame of contents
|
49 |
+
images = contents[:, :, 0:1, :, :] # B, C, F, H, W, non contiguous view is fine
|
50 |
+
|
51 |
+
with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
|
52 |
+
clip_context = clip.visual(images)
|
53 |
+
clip_context = clip_context.to(torch.float16) # convert to fp16
|
54 |
+
|
55 |
+
# encode image latent for I2V
|
56 |
+
B, _, _, lat_h, lat_w = latent.shape
|
57 |
+
F = contents.shape[2]
|
58 |
+
|
59 |
+
# Create mask for the required number of frames
|
60 |
+
msk = torch.ones(1, F, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
|
61 |
+
msk[:, 1:] = 0
|
62 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
63 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
64 |
+
msk = msk.transpose(1, 2) # 1, F, 4, H, W -> 1, 4, F, H, W
|
65 |
+
msk = msk.repeat(B, 1, 1, 1, 1) # B, 4, F, H, W
|
66 |
+
|
67 |
+
# Zero padding for the required number of frames only
|
68 |
+
padding_frames = F - 1 # The first frame is the input image
|
69 |
+
images_resized = torch.concat([images, torch.zeros(B, 3, padding_frames, h, w, device=vae.device)], dim=2)
|
70 |
+
with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
|
71 |
+
y = vae.encode(images_resized)
|
72 |
+
y = torch.stack(y, dim=0) # B, C, F, H, W
|
73 |
+
|
74 |
+
y = y[:, :, :F] # may be not needed
|
75 |
+
y = y.to(vae.dtype) # convert to bfloat16
|
76 |
+
y = torch.concat([msk, y], dim=1) # B, 4 + C, F, H, W
|
77 |
+
|
78 |
+
else:
|
79 |
+
clip_context = None
|
80 |
+
y = None
|
81 |
+
|
82 |
+
# control videos
|
83 |
+
if batch[0].control_content is not None:
|
84 |
+
control_contents = torch.stack([torch.from_numpy(item.control_content) for item in batch])
|
85 |
+
if len(control_contents.shape) == 4:
|
86 |
+
control_contents = control_contents.unsqueeze(1)
|
87 |
+
control_contents = control_contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
88 |
+
control_contents = control_contents.to(vae.device, dtype=vae.dtype)
|
89 |
+
control_contents = control_contents / 127.5 - 1.0 # normalize to [-1, 1]
|
90 |
+
with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
|
91 |
+
control_latent = vae.encode(control_contents) # list of Tensor[C, F, H, W]
|
92 |
+
control_latent = torch.stack(control_latent, dim=0) # B, C, F, H, W
|
93 |
+
control_latent = control_latent.to(vae.dtype) # convert to bfloat16
|
94 |
+
else:
|
95 |
+
control_latent = None
|
96 |
+
|
97 |
+
# # debug: decode and save
|
98 |
+
# with torch.no_grad():
|
99 |
+
# latent_to_decode = latent / vae.config.scaling_factor
|
100 |
+
# images = vae.decode(latent_to_decode, return_dict=False)[0]
|
101 |
+
# images = (images / 2 + 0.5).clamp(0, 1)
|
102 |
+
# images = images.cpu().float().numpy()
|
103 |
+
# images = (images * 255).astype(np.uint8)
|
104 |
+
# images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
|
105 |
+
# for b in range(images.shape[0]):
|
106 |
+
# for f in range(images.shape[1]):
|
107 |
+
# fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
|
108 |
+
# img = Image.fromarray(images[b, f])
|
109 |
+
# img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
|
110 |
+
|
111 |
+
for i, item in enumerate(batch):
|
112 |
+
l = latent[i]
|
113 |
+
cctx = clip_context[i] if clip is not None else None
|
114 |
+
y_i = y[i] if clip is not None else None
|
115 |
+
control_latent_i = control_latent[i] if control_latent is not None else None
|
116 |
+
# print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
|
117 |
+
save_latent_cache_wan(item, l, cctx, y_i, control_latent_i)
|
118 |
+
|
119 |
+
|
120 |
+
def main(args):
|
121 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
122 |
+
device = torch.device(device)
|
123 |
+
|
124 |
+
# Load dataset config
|
125 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
126 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
127 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
128 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
|
129 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
130 |
+
|
131 |
+
datasets = train_dataset_group.datasets
|
132 |
+
|
133 |
+
if args.debug_mode is not None:
|
134 |
+
cache_latents.show_datasets(
|
135 |
+
datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
|
136 |
+
)
|
137 |
+
return
|
138 |
+
|
139 |
+
assert args.vae is not None, "vae checkpoint is required"
|
140 |
+
|
141 |
+
vae_path = args.vae
|
142 |
+
|
143 |
+
logger.info(f"Loading VAE model from {vae_path}")
|
144 |
+
vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
|
145 |
+
cache_device = torch.device("cpu") if args.vae_cache_cpu else None
|
146 |
+
vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)
|
147 |
+
|
148 |
+
if args.clip is not None:
|
149 |
+
clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
|
150 |
+
clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
|
151 |
+
else:
|
152 |
+
clip = None
|
153 |
+
|
154 |
+
# Encode images
|
155 |
+
def encode(one_batch: list[ItemInfo]):
|
156 |
+
encode_and_save_batch(vae, clip, one_batch)
|
157 |
+
|
158 |
+
cache_latents.encode_datasets(datasets, encode, args)
|
159 |
+
|
160 |
+
|
161 |
+
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
162 |
+
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
|
163 |
+
parser.add_argument(
|
164 |
+
"--clip",
|
165 |
+
type=str,
|
166 |
+
default=None,
|
167 |
+
help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
|
168 |
+
)
|
169 |
+
return parser
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == "__main__":
|
173 |
+
parser = cache_latents.setup_parser_common()
|
174 |
+
parser = wan_setup_parser(parser)
|
175 |
+
|
176 |
+
args = parser.parse_args()
|
177 |
+
main(args)
|
wan_cache_text_encoder_outputs.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from dataset import config_utils
|
10 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
11 |
+
import accelerate
|
12 |
+
|
13 |
+
from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan
|
14 |
+
|
15 |
+
# for t5 config: all Wan2.1 models have the same config for t5
|
16 |
+
from wan.configs import wan_t2v_14B
|
17 |
+
|
18 |
+
import cache_text_encoder_outputs
|
19 |
+
import logging
|
20 |
+
|
21 |
+
from utils.model_utils import str_to_dtype
|
22 |
+
from wan.modules.t5 import T5EncoderModel
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
logging.basicConfig(level=logging.INFO)
|
26 |
+
|
27 |
+
|
28 |
+
def encode_and_save_batch(
|
29 |
+
text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
|
30 |
+
):
|
31 |
+
prompts = [item.caption for item in batch]
|
32 |
+
# print(prompts)
|
33 |
+
|
34 |
+
# encode prompt
|
35 |
+
with torch.no_grad():
|
36 |
+
if accelerator is not None:
|
37 |
+
with accelerator.autocast():
|
38 |
+
context = text_encoder(prompts, device)
|
39 |
+
else:
|
40 |
+
context = text_encoder(prompts, device)
|
41 |
+
|
42 |
+
# save prompt cache
|
43 |
+
for item, ctx in zip(batch, context):
|
44 |
+
save_text_encoder_output_cache_wan(item, ctx)
|
45 |
+
|
46 |
+
|
47 |
+
def main(args):
|
48 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
device = torch.device(device)
|
50 |
+
|
51 |
+
# Load dataset config
|
52 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
53 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
54 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
55 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
|
56 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
57 |
+
|
58 |
+
datasets = train_dataset_group.datasets
|
59 |
+
|
60 |
+
# define accelerator for fp8 inference
|
61 |
+
config = wan_t2v_14B.t2v_14B # all Wan2.1 models have the same config for t5
|
62 |
+
accelerator = None
|
63 |
+
if args.fp8_t5:
|
64 |
+
accelerator = accelerate.Accelerator(mixed_precision="bf16" if config.t5_dtype == torch.bfloat16 else "fp16")
|
65 |
+
|
66 |
+
# prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
|
67 |
+
all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
|
68 |
+
|
69 |
+
# Load T5
|
70 |
+
logger.info(f"Loading T5: {args.t5}")
|
71 |
+
text_encoder = T5EncoderModel(
|
72 |
+
text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=args.t5, fp8=args.fp8_t5
|
73 |
+
)
|
74 |
+
|
75 |
+
# Encode with T5
|
76 |
+
logger.info("Encoding with T5")
|
77 |
+
|
78 |
+
def encode_for_text_encoder(batch: list[ItemInfo]):
|
79 |
+
encode_and_save_batch(text_encoder, batch, device, accelerator)
|
80 |
+
|
81 |
+
cache_text_encoder_outputs.process_text_encoder_batches(
|
82 |
+
args.num_workers,
|
83 |
+
args.skip_existing,
|
84 |
+
args.batch_size,
|
85 |
+
datasets,
|
86 |
+
all_cache_files_for_dataset,
|
87 |
+
all_cache_paths_for_dataset,
|
88 |
+
encode_for_text_encoder,
|
89 |
+
)
|
90 |
+
del text_encoder
|
91 |
+
|
92 |
+
# remove cache files not in dataset
|
93 |
+
cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
|
94 |
+
|
95 |
+
|
96 |
+
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
97 |
+
parser.add_argument("--t5", type=str, default=None, required=True, help="text encoder (T5) checkpoint path")
|
98 |
+
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
|
99 |
+
return parser
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
parser = cache_text_encoder_outputs.setup_parser_common()
|
104 |
+
parser = wan_setup_parser(parser)
|
105 |
+
|
106 |
+
args = parser.parse_args()
|
107 |
+
main(args)
|
wan_generate_video.py
ADDED
@@ -0,0 +1,1902 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datetime import datetime
|
3 |
+
import gc
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import copy
|
10 |
+
from types import ModuleType, SimpleNamespace
|
11 |
+
from typing import Tuple, Optional, List, Union, Any, Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import accelerate
|
15 |
+
from accelerate import Accelerator
|
16 |
+
from safetensors.torch import load_file, save_file
|
17 |
+
from safetensors import safe_open
|
18 |
+
from PIL import Image
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import torchvision.transforms.functional as TF
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from networks import lora_wan
|
25 |
+
from utils.safetensors_utils import mem_eff_save_file, load_safetensors
|
26 |
+
from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
|
27 |
+
import wan
|
28 |
+
from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
|
29 |
+
from wan.modules.vae import WanVAE
|
30 |
+
from wan.modules.t5 import T5EncoderModel
|
31 |
+
from wan.modules.clip import CLIPModel
|
32 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
33 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
34 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
35 |
+
|
36 |
+
try:
|
37 |
+
from lycoris.kohya import create_network_from_weights
|
38 |
+
except:
|
39 |
+
pass
|
40 |
+
|
41 |
+
from utils.model_utils import str_to_dtype
|
42 |
+
from utils.device_utils import clean_memory_on_device
|
43 |
+
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
|
44 |
+
from dataset.image_video_dataset import load_video
|
45 |
+
|
46 |
+
import logging
|
47 |
+
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
logging.basicConfig(level=logging.INFO)
|
50 |
+
|
51 |
+
|
52 |
+
class GenerationSettings:
|
53 |
+
def __init__(
|
54 |
+
self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
|
55 |
+
):
|
56 |
+
self.device = device
|
57 |
+
self.cfg = cfg
|
58 |
+
self.dit_dtype = dit_dtype
|
59 |
+
self.dit_weight_dtype = dit_weight_dtype
|
60 |
+
self.vae_dtype = vae_dtype
|
61 |
+
|
62 |
+
|
63 |
+
def parse_args() -> argparse.Namespace:
|
64 |
+
"""parse command line arguments"""
|
65 |
+
parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
|
66 |
+
|
67 |
+
# WAN arguments
|
68 |
+
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
|
69 |
+
parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
|
70 |
+
parser.add_argument(
|
71 |
+
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
|
72 |
+
)
|
73 |
+
|
74 |
+
parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
|
75 |
+
parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
|
76 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
|
77 |
+
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
|
78 |
+
parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
|
79 |
+
parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
|
80 |
+
# LoRA
|
81 |
+
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
|
82 |
+
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
|
83 |
+
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
|
84 |
+
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
|
85 |
+
parser.add_argument(
|
86 |
+
"--save_merged_model",
|
87 |
+
type=str,
|
88 |
+
default=None,
|
89 |
+
help="Save merged model to path. If specified, no inference will be performed.",
|
90 |
+
)
|
91 |
+
|
92 |
+
# inference
|
93 |
+
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
|
94 |
+
parser.add_argument(
|
95 |
+
"--negative_prompt",
|
96 |
+
type=str,
|
97 |
+
default=None,
|
98 |
+
help="negative prompt for generation, use default negative prompt if not specified",
|
99 |
+
)
|
100 |
+
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
|
101 |
+
parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
|
102 |
+
parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
|
103 |
+
parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
|
104 |
+
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
|
105 |
+
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
106 |
+
parser.add_argument(
|
107 |
+
"--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--guidance_scale",
|
111 |
+
type=float,
|
112 |
+
default=5.0,
|
113 |
+
help="Guidance scale for classifier free guidance. Default is 5.0.",
|
114 |
+
)
|
115 |
+
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
|
116 |
+
parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
|
117 |
+
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
|
118 |
+
parser.add_argument(
|
119 |
+
"--control_path",
|
120 |
+
type=str,
|
121 |
+
default=None,
|
122 |
+
help="path to control video for inference with controlnet. video file or directory with images",
|
123 |
+
)
|
124 |
+
parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
|
125 |
+
parser.add_argument(
|
126 |
+
"--cfg_skip_mode",
|
127 |
+
type=str,
|
128 |
+
default="none",
|
129 |
+
choices=["early", "late", "middle", "early_late", "alternate", "none"],
|
130 |
+
help="CFG skip mode. each mode skips different parts of the CFG. "
|
131 |
+
" early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--cfg_apply_ratio",
|
135 |
+
type=float,
|
136 |
+
default=None,
|
137 |
+
help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--slg_scale",
|
144 |
+
type=float,
|
145 |
+
default=3.0,
|
146 |
+
help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
|
147 |
+
)
|
148 |
+
parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
|
149 |
+
parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
|
150 |
+
parser.add_argument(
|
151 |
+
"--slg_mode",
|
152 |
+
type=str,
|
153 |
+
default=None,
|
154 |
+
choices=["original", "uncond"],
|
155 |
+
help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
|
156 |
+
)
|
157 |
+
|
158 |
+
# Flow Matching
|
159 |
+
parser.add_argument(
|
160 |
+
"--flow_shift",
|
161 |
+
type=float,
|
162 |
+
default=None,
|
163 |
+
help="Shift factor for flow matching schedulers. Default depends on task.",
|
164 |
+
)
|
165 |
+
|
166 |
+
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
167 |
+
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
168 |
+
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
|
169 |
+
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
|
170 |
+
parser.add_argument(
|
171 |
+
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--attn_mode",
|
175 |
+
type=str,
|
176 |
+
default="torch",
|
177 |
+
choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
|
178 |
+
help="attention mode",
|
179 |
+
)
|
180 |
+
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
181 |
+
parser.add_argument(
|
182 |
+
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
|
183 |
+
)
|
184 |
+
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
|
185 |
+
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
|
186 |
+
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
|
187 |
+
parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
|
188 |
+
parser.add_argument(
|
189 |
+
"--compile_args",
|
190 |
+
nargs=4,
|
191 |
+
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
|
192 |
+
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
|
193 |
+
help="Torch.compile settings",
|
194 |
+
)
|
195 |
+
|
196 |
+
# New arguments for batch and interactive modes
|
197 |
+
parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
|
198 |
+
parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
|
199 |
+
|
200 |
+
args = parser.parse_args()
|
201 |
+
|
202 |
+
# Validate arguments
|
203 |
+
if args.from_file and args.interactive:
|
204 |
+
raise ValueError("Cannot use both --from_file and --interactive at the same time")
|
205 |
+
|
206 |
+
if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None:
|
207 |
+
raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified")
|
208 |
+
|
209 |
+
assert (args.latent_path is None or len(args.latent_path) == 0) or (
|
210 |
+
args.output_type == "images" or args.output_type == "video"
|
211 |
+
), "latent_path is only supported for images or video output"
|
212 |
+
|
213 |
+
return args
|
214 |
+
|
215 |
+
|
216 |
+
def parse_prompt_line(line: str) -> Dict[str, Any]:
|
217 |
+
"""Parse a prompt line into a dictionary of argument overrides
|
218 |
+
|
219 |
+
Args:
|
220 |
+
line: Prompt line with options
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Dict[str, Any]: Dictionary of argument overrides
|
224 |
+
"""
|
225 |
+
# TODO common function with hv_train_network.line_to_prompt_dict
|
226 |
+
parts = line.split(" --")
|
227 |
+
prompt = parts[0].strip()
|
228 |
+
|
229 |
+
# Create dictionary of overrides
|
230 |
+
overrides = {"prompt": prompt}
|
231 |
+
|
232 |
+
for part in parts[1:]:
|
233 |
+
if not part.strip():
|
234 |
+
continue
|
235 |
+
option_parts = part.split(" ", 1)
|
236 |
+
option = option_parts[0].strip()
|
237 |
+
value = option_parts[1].strip() if len(option_parts) > 1 else ""
|
238 |
+
|
239 |
+
# Map options to argument names
|
240 |
+
if option == "w":
|
241 |
+
overrides["video_size_width"] = int(value)
|
242 |
+
elif option == "h":
|
243 |
+
overrides["video_size_height"] = int(value)
|
244 |
+
elif option == "f":
|
245 |
+
overrides["video_length"] = int(value)
|
246 |
+
elif option == "d":
|
247 |
+
overrides["seed"] = int(value)
|
248 |
+
elif option == "s":
|
249 |
+
overrides["infer_steps"] = int(value)
|
250 |
+
elif option == "g" or option == "l":
|
251 |
+
overrides["guidance_scale"] = float(value)
|
252 |
+
elif option == "fs":
|
253 |
+
overrides["flow_shift"] = float(value)
|
254 |
+
elif option == "i":
|
255 |
+
overrides["image_path"] = value
|
256 |
+
elif option == "cn":
|
257 |
+
overrides["control_path"] = value
|
258 |
+
elif option == "n":
|
259 |
+
overrides["negative_prompt"] = value
|
260 |
+
|
261 |
+
return overrides
|
262 |
+
|
263 |
+
|
264 |
+
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
|
265 |
+
"""Apply overrides to args
|
266 |
+
|
267 |
+
Args:
|
268 |
+
args: Original arguments
|
269 |
+
overrides: Dictionary of overrides
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
argparse.Namespace: New arguments with overrides applied
|
273 |
+
"""
|
274 |
+
args_copy = copy.deepcopy(args)
|
275 |
+
|
276 |
+
for key, value in overrides.items():
|
277 |
+
if key == "video_size_width":
|
278 |
+
args_copy.video_size[1] = value
|
279 |
+
elif key == "video_size_height":
|
280 |
+
args_copy.video_size[0] = value
|
281 |
+
else:
|
282 |
+
setattr(args_copy, key, value)
|
283 |
+
|
284 |
+
return args_copy
|
285 |
+
|
286 |
+
|
287 |
+
def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
|
288 |
+
"""Return default values for each task
|
289 |
+
|
290 |
+
Args:
|
291 |
+
task: task name (t2v, t2i, i2v etc.)
|
292 |
+
size: size of the video (width, height)
|
293 |
+
|
294 |
+
Returns:
|
295 |
+
Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
|
296 |
+
"""
|
297 |
+
width, height = size if size else (0, 0)
|
298 |
+
|
299 |
+
if "t2i" in task:
|
300 |
+
return 50, 5.0, 1, False
|
301 |
+
elif "i2v" in task:
|
302 |
+
flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
|
303 |
+
return 40, flow_shift, 81, True
|
304 |
+
else: # t2v or default
|
305 |
+
return 50, 5.0, 81, False
|
306 |
+
|
307 |
+
|
308 |
+
def setup_args(args: argparse.Namespace) -> argparse.Namespace:
|
309 |
+
"""Validate and set default values for optional arguments
|
310 |
+
|
311 |
+
Args:
|
312 |
+
args: command line arguments
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
argparse.Namespace: updated arguments
|
316 |
+
"""
|
317 |
+
# Get default values for the task
|
318 |
+
infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))
|
319 |
+
|
320 |
+
# Apply default values to unset arguments
|
321 |
+
if args.infer_steps is None:
|
322 |
+
args.infer_steps = infer_steps
|
323 |
+
if args.flow_shift is None:
|
324 |
+
args.flow_shift = flow_shift
|
325 |
+
if args.video_length is None:
|
326 |
+
args.video_length = video_length
|
327 |
+
|
328 |
+
# Force video_length to 1 for t2i tasks
|
329 |
+
if "t2i" in args.task:
|
330 |
+
assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
|
331 |
+
|
332 |
+
# parse slg_layers
|
333 |
+
if args.slg_layers is not None:
|
334 |
+
args.slg_layers = list(map(int, args.slg_layers.split(",")))
|
335 |
+
|
336 |
+
return args
|
337 |
+
|
338 |
+
|
339 |
+
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
|
340 |
+
"""Validate video size and length
|
341 |
+
|
342 |
+
Args:
|
343 |
+
args: command line arguments
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
Tuple[int, int, int]: (height, width, video_length)
|
347 |
+
"""
|
348 |
+
height = args.video_size[0]
|
349 |
+
width = args.video_size[1]
|
350 |
+
size = f"{width}*{height}"
|
351 |
+
|
352 |
+
if size not in SUPPORTED_SIZES[args.task]:
|
353 |
+
logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")
|
354 |
+
|
355 |
+
video_length = args.video_length
|
356 |
+
|
357 |
+
if height % 8 != 0 or width % 8 != 0:
|
358 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
359 |
+
|
360 |
+
return height, width, video_length
|
361 |
+
|
362 |
+
|
363 |
+
def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
|
364 |
+
"""calculate dimensions for the generation
|
365 |
+
|
366 |
+
Args:
|
367 |
+
video_size: video frame size (height, width)
|
368 |
+
video_length: number of frames in the video
|
369 |
+
config: model configuration
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
Tuple[Tuple[int, int, int, int], int]:
|
373 |
+
((channels, frames, height, width), seq_len)
|
374 |
+
"""
|
375 |
+
height, width = video_size
|
376 |
+
frames = video_length
|
377 |
+
|
378 |
+
# calculate latent space dimensions
|
379 |
+
lat_f = (frames - 1) // config.vae_stride[0] + 1
|
380 |
+
lat_h = height // config.vae_stride[1]
|
381 |
+
lat_w = width // config.vae_stride[2]
|
382 |
+
|
383 |
+
# calculate sequence length
|
384 |
+
seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)
|
385 |
+
|
386 |
+
return ((16, lat_f, lat_h, lat_w), seq_len)
|
387 |
+
|
388 |
+
|
389 |
+
def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
|
390 |
+
"""load VAE model
|
391 |
+
|
392 |
+
Args:
|
393 |
+
args: command line arguments
|
394 |
+
config: model configuration
|
395 |
+
device: device to use
|
396 |
+
dtype: data type for the model
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
WanVAE: loaded VAE model
|
400 |
+
"""
|
401 |
+
vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)
|
402 |
+
|
403 |
+
logger.info(f"Loading VAE model from {vae_path}")
|
404 |
+
cache_device = torch.device("cpu") if args.vae_cache_cpu else None
|
405 |
+
vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
|
406 |
+
return vae
|
407 |
+
|
408 |
+
|
409 |
+
def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
|
410 |
+
"""load text encoder (T5) model
|
411 |
+
|
412 |
+
Args:
|
413 |
+
args: command line arguments
|
414 |
+
config: model configuration
|
415 |
+
device: device to use
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
T5EncoderModel: loaded text encoder model
|
419 |
+
"""
|
420 |
+
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
|
421 |
+
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)
|
422 |
+
|
423 |
+
text_encoder = T5EncoderModel(
|
424 |
+
text_len=config.text_len,
|
425 |
+
dtype=config.t5_dtype,
|
426 |
+
device=device,
|
427 |
+
checkpoint_path=checkpoint_path,
|
428 |
+
tokenizer_path=tokenizer_path,
|
429 |
+
weight_path=args.t5,
|
430 |
+
fp8=args.fp8_t5,
|
431 |
+
)
|
432 |
+
|
433 |
+
return text_encoder
|
434 |
+
|
435 |
+
|
436 |
+
def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
|
437 |
+
"""load CLIP model (for I2V only)
|
438 |
+
|
439 |
+
Args:
|
440 |
+
args: command line arguments
|
441 |
+
config: model configuration
|
442 |
+
device: device to use
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
CLIPModel: loaded CLIP model
|
446 |
+
"""
|
447 |
+
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
|
448 |
+
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)
|
449 |
+
|
450 |
+
clip = CLIPModel(
|
451 |
+
dtype=config.clip_dtype,
|
452 |
+
device=device,
|
453 |
+
checkpoint_path=checkpoint_path,
|
454 |
+
tokenizer_path=tokenizer_path,
|
455 |
+
weight_path=args.clip,
|
456 |
+
)
|
457 |
+
|
458 |
+
return clip
|
459 |
+
|
460 |
+
|
461 |
+
def load_dit_model(
|
462 |
+
args: argparse.Namespace,
|
463 |
+
config,
|
464 |
+
device: torch.device,
|
465 |
+
dit_dtype: torch.dtype,
|
466 |
+
dit_weight_dtype: Optional[torch.dtype] = None,
|
467 |
+
is_i2v: bool = False,
|
468 |
+
) -> WanModel:
|
469 |
+
"""load DiT model
|
470 |
+
|
471 |
+
Args:
|
472 |
+
args: command line arguments
|
473 |
+
config: model configuration
|
474 |
+
device: device to use
|
475 |
+
dit_dtype: data type for the model
|
476 |
+
dit_weight_dtype: data type for the model weights. None for as-is
|
477 |
+
is_i2v: I2V mode
|
478 |
+
|
479 |
+
Returns:
|
480 |
+
WanModel: loaded DiT model
|
481 |
+
"""
|
482 |
+
loading_device = "cpu"
|
483 |
+
if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
|
484 |
+
loading_device = device
|
485 |
+
|
486 |
+
loading_weight_dtype = dit_weight_dtype
|
487 |
+
if args.fp8_scaled or args.lora_weight is not None:
|
488 |
+
loading_weight_dtype = dit_dtype # load as-is
|
489 |
+
|
490 |
+
# do not fp8 optimize because we will merge LoRA weights
|
491 |
+
model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)
|
492 |
+
|
493 |
+
return model
|
494 |
+
|
495 |
+
|
496 |
+
def merge_lora_weights(
|
497 |
+
lora_module: ModuleType,
|
498 |
+
model: torch.nn.Module,
|
499 |
+
args: argparse.Namespace,
|
500 |
+
device: torch.device,
|
501 |
+
converter: Optional[callable] = None,
|
502 |
+
) -> None:
|
503 |
+
"""merge LoRA weights to the model
|
504 |
+
|
505 |
+
Args:
|
506 |
+
lora_module: LoRA module, e.g. lora_wan
|
507 |
+
model: DiT model
|
508 |
+
args: command line arguments
|
509 |
+
device: device to use
|
510 |
+
converter: Optional callable to convert weights
|
511 |
+
"""
|
512 |
+
if args.lora_weight is None or len(args.lora_weight) == 0:
|
513 |
+
return
|
514 |
+
|
515 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
516 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
517 |
+
lora_multiplier = args.lora_multiplier[i]
|
518 |
+
else:
|
519 |
+
lora_multiplier = 1.0
|
520 |
+
|
521 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
522 |
+
weights_sd = load_file(lora_weight)
|
523 |
+
if converter is not None:
|
524 |
+
weights_sd = converter(weights_sd)
|
525 |
+
|
526 |
+
# apply include/exclude patterns
|
527 |
+
original_key_count = len(weights_sd.keys())
|
528 |
+
if args.include_patterns is not None and len(args.include_patterns) > i:
|
529 |
+
include_pattern = args.include_patterns[i]
|
530 |
+
regex_include = re.compile(include_pattern)
|
531 |
+
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
|
532 |
+
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
|
533 |
+
if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
|
534 |
+
original_key_count_ex = len(weights_sd.keys())
|
535 |
+
exclude_pattern = args.exclude_patterns[i]
|
536 |
+
regex_exclude = re.compile(exclude_pattern)
|
537 |
+
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
|
538 |
+
logger.info(
|
539 |
+
f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
|
540 |
+
)
|
541 |
+
if len(weights_sd) != original_key_count:
|
542 |
+
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
|
543 |
+
remaining_keys.sort()
|
544 |
+
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
|
545 |
+
if len(weights_sd) == 0:
|
546 |
+
logger.warning(f"No keys left after filtering.")
|
547 |
+
|
548 |
+
if args.lycoris:
|
549 |
+
lycoris_net, _ = create_network_from_weights(
|
550 |
+
multiplier=lora_multiplier,
|
551 |
+
file=None,
|
552 |
+
weights_sd=weights_sd,
|
553 |
+
unet=model,
|
554 |
+
text_encoder=None,
|
555 |
+
vae=None,
|
556 |
+
for_inference=True,
|
557 |
+
)
|
558 |
+
lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
|
559 |
+
else:
|
560 |
+
network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
|
561 |
+
network.merge_to(None, model, weights_sd, device=device, non_blocking=True)
|
562 |
+
|
563 |
+
synchronize_device(device)
|
564 |
+
logger.info("LoRA weights loaded")
|
565 |
+
|
566 |
+
# save model here before casting to dit_weight_dtype
|
567 |
+
if args.save_merged_model:
|
568 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
569 |
+
mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory
|
570 |
+
logger.info("Merged model saved")
|
571 |
+
|
572 |
+
|
573 |
+
def optimize_model(
|
574 |
+
model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
|
575 |
+
) -> None:
|
576 |
+
"""optimize the model (FP8 conversion, device move etc.)
|
577 |
+
|
578 |
+
Args:
|
579 |
+
model: dit model
|
580 |
+
args: command line arguments
|
581 |
+
device: device to use
|
582 |
+
dit_dtype: dtype for the model
|
583 |
+
dit_weight_dtype: dtype for the model weights
|
584 |
+
"""
|
585 |
+
if args.fp8_scaled:
|
586 |
+
# load state dict as-is and optimize to fp8
|
587 |
+
state_dict = model.state_dict()
|
588 |
+
|
589 |
+
# if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
|
590 |
+
move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
|
591 |
+
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
|
592 |
+
|
593 |
+
info = model.load_state_dict(state_dict, strict=True, assign=True)
|
594 |
+
logger.info(f"Loaded FP8 optimized weights: {info}")
|
595 |
+
|
596 |
+
if args.blocks_to_swap == 0:
|
597 |
+
model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
|
598 |
+
else:
|
599 |
+
# simple cast to dit_dtype
|
600 |
+
target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
|
601 |
+
target_device = None
|
602 |
+
|
603 |
+
if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
|
604 |
+
logger.info(f"Convert model to {dit_weight_dtype}")
|
605 |
+
target_dtype = dit_weight_dtype
|
606 |
+
|
607 |
+
if args.blocks_to_swap == 0:
|
608 |
+
logger.info(f"Move model to device: {device}")
|
609 |
+
target_device = device
|
610 |
+
|
611 |
+
model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
|
612 |
+
|
613 |
+
if args.compile:
|
614 |
+
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
|
615 |
+
logger.info(
|
616 |
+
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
|
617 |
+
)
|
618 |
+
torch._dynamo.config.cache_size_limit = 32
|
619 |
+
for i in range(len(model.blocks)):
|
620 |
+
model.blocks[i] = torch.compile(
|
621 |
+
model.blocks[i],
|
622 |
+
backend=compile_backend,
|
623 |
+
mode=compile_mode,
|
624 |
+
dynamic=compile_dynamic.lower() in "true",
|
625 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
626 |
+
)
|
627 |
+
|
628 |
+
if args.blocks_to_swap > 0:
|
629 |
+
logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
|
630 |
+
model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
|
631 |
+
model.move_to_device_except_swap_blocks(device)
|
632 |
+
model.prepare_block_swap_before_forward()
|
633 |
+
else:
|
634 |
+
# make sure the model is on the right device
|
635 |
+
model.to(device)
|
636 |
+
|
637 |
+
model.eval().requires_grad_(False)
|
638 |
+
clean_memory_on_device(device)
|
639 |
+
|
640 |
+
|
641 |
+
def prepare_t2v_inputs(
|
642 |
+
args: argparse.Namespace,
|
643 |
+
config,
|
644 |
+
accelerator: Accelerator,
|
645 |
+
device: torch.device,
|
646 |
+
vae: Optional[WanVAE] = None,
|
647 |
+
encoded_context: Optional[Dict] = None,
|
648 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
649 |
+
"""Prepare inputs for T2V
|
650 |
+
|
651 |
+
Args:
|
652 |
+
args: command line arguments
|
653 |
+
config: model configuration
|
654 |
+
accelerator: Accelerator instance
|
655 |
+
device: device to use
|
656 |
+
vae: VAE model for control video encoding
|
657 |
+
encoded_context: Pre-encoded text context
|
658 |
+
|
659 |
+
Returns:
|
660 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
661 |
+
(noise, context, context_null, (arg_c, arg_null))
|
662 |
+
"""
|
663 |
+
# Prepare inputs for T2V
|
664 |
+
# calculate dimensions and sequence length
|
665 |
+
height, width = args.video_size
|
666 |
+
frames = args.video_length
|
667 |
+
(_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
|
668 |
+
target_shape = (16, lat_f, lat_h, lat_w)
|
669 |
+
|
670 |
+
# configure negative prompt
|
671 |
+
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
|
672 |
+
|
673 |
+
# set seed
|
674 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
675 |
+
if not args.cpu_noise:
|
676 |
+
seed_g = torch.Generator(device=device)
|
677 |
+
seed_g.manual_seed(seed)
|
678 |
+
else:
|
679 |
+
# ComfyUI compatible noise
|
680 |
+
seed_g = torch.manual_seed(seed)
|
681 |
+
|
682 |
+
if encoded_context is None:
|
683 |
+
# load text encoder
|
684 |
+
text_encoder = load_text_encoder(args, config, device)
|
685 |
+
text_encoder.model.to(device)
|
686 |
+
|
687 |
+
# encode prompt
|
688 |
+
with torch.no_grad():
|
689 |
+
if args.fp8_t5:
|
690 |
+
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
|
691 |
+
context = text_encoder([args.prompt], device)
|
692 |
+
context_null = text_encoder([n_prompt], device)
|
693 |
+
else:
|
694 |
+
context = text_encoder([args.prompt], device)
|
695 |
+
context_null = text_encoder([n_prompt], device)
|
696 |
+
|
697 |
+
# free text encoder and clean memory
|
698 |
+
del text_encoder
|
699 |
+
clean_memory_on_device(device)
|
700 |
+
else:
|
701 |
+
# Use pre-encoded context
|
702 |
+
context = encoded_context["context"]
|
703 |
+
context_null = encoded_context["context_null"]
|
704 |
+
|
705 |
+
# Fun-Control: encode control video to latent space
|
706 |
+
if config.is_fun_control:
|
707 |
+
# TODO use same resizing as for image
|
708 |
+
logger.info(f"Encoding control video to latent space")
|
709 |
+
# C, F, H, W
|
710 |
+
control_video = load_control_video(args.control_path, frames, height, width).to(device)
|
711 |
+
vae.to_device(device)
|
712 |
+
with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
|
713 |
+
control_latent = vae.encode([control_video])[0]
|
714 |
+
y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent
|
715 |
+
vae.to_device("cpu")
|
716 |
+
else:
|
717 |
+
y = None
|
718 |
+
|
719 |
+
# generate noise
|
720 |
+
noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
|
721 |
+
noise = noise.to(device)
|
722 |
+
|
723 |
+
# prepare model input arguments
|
724 |
+
arg_c = {"context": context, "seq_len": seq_len}
|
725 |
+
arg_null = {"context": context_null, "seq_len": seq_len}
|
726 |
+
if y is not None:
|
727 |
+
arg_c["y"] = [y]
|
728 |
+
arg_null["y"] = [y]
|
729 |
+
|
730 |
+
return noise, context, context_null, (arg_c, arg_null)
|
731 |
+
|
732 |
+
|
733 |
+
def prepare_i2v_inputs(
|
734 |
+
args: argparse.Namespace,
|
735 |
+
config,
|
736 |
+
accelerator: Accelerator,
|
737 |
+
device: torch.device,
|
738 |
+
vae: WanVAE,
|
739 |
+
encoded_context: Optional[Dict] = None,
|
740 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
741 |
+
"""Prepare inputs for I2V
|
742 |
+
|
743 |
+
Args:
|
744 |
+
args: command line arguments
|
745 |
+
config: model configuration
|
746 |
+
accelerator: Accelerator instance
|
747 |
+
device: device to use
|
748 |
+
vae: VAE model, used for image encoding
|
749 |
+
encoded_context: Pre-encoded text context
|
750 |
+
|
751 |
+
Returns:
|
752 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
753 |
+
(noise, context, context_null, y, (arg_c, arg_null))
|
754 |
+
"""
|
755 |
+
# get video dimensions
|
756 |
+
height, width = args.video_size
|
757 |
+
frames = args.video_length
|
758 |
+
max_area = width * height
|
759 |
+
|
760 |
+
# load image
|
761 |
+
img = Image.open(args.image_path).convert("RGB")
|
762 |
+
|
763 |
+
# convert to numpy
|
764 |
+
img_cv2 = np.array(img) # PIL to numpy
|
765 |
+
|
766 |
+
# convert to tensor (-1 to 1)
|
767 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
768 |
+
|
769 |
+
# end frame image
|
770 |
+
if args.end_image_path is not None:
|
771 |
+
end_img = Image.open(args.end_image_path).convert("RGB")
|
772 |
+
end_img_cv2 = np.array(end_img) # PIL to numpy
|
773 |
+
else:
|
774 |
+
end_img = None
|
775 |
+
end_img_cv2 = None
|
776 |
+
has_end_image = end_img is not None
|
777 |
+
|
778 |
+
# calculate latent dimensions: keep aspect ratio
|
779 |
+
height, width = img_tensor.shape[1:]
|
780 |
+
aspect_ratio = height / width
|
781 |
+
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
|
782 |
+
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
|
783 |
+
height = lat_h * config.vae_stride[1]
|
784 |
+
width = lat_w * config.vae_stride[2]
|
785 |
+
lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames
|
786 |
+
max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])
|
787 |
+
|
788 |
+
# set seed
|
789 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
790 |
+
if not args.cpu_noise:
|
791 |
+
seed_g = torch.Generator(device=device)
|
792 |
+
seed_g.manual_seed(seed)
|
793 |
+
else:
|
794 |
+
# ComfyUI compatible noise
|
795 |
+
seed_g = torch.manual_seed(seed)
|
796 |
+
|
797 |
+
# generate noise
|
798 |
+
noise = torch.randn(
|
799 |
+
16,
|
800 |
+
lat_f + (1 if has_end_image else 0),
|
801 |
+
lat_h,
|
802 |
+
lat_w,
|
803 |
+
dtype=torch.float32,
|
804 |
+
generator=seed_g,
|
805 |
+
device=device if not args.cpu_noise else "cpu",
|
806 |
+
)
|
807 |
+
noise = noise.to(device)
|
808 |
+
|
809 |
+
# configure negative prompt
|
810 |
+
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
|
811 |
+
|
812 |
+
if encoded_context is None:
|
813 |
+
# load text encoder
|
814 |
+
text_encoder = load_text_encoder(args, config, device)
|
815 |
+
text_encoder.model.to(device)
|
816 |
+
|
817 |
+
# encode prompt
|
818 |
+
with torch.no_grad():
|
819 |
+
if args.fp8_t5:
|
820 |
+
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
|
821 |
+
context = text_encoder([args.prompt], device)
|
822 |
+
context_null = text_encoder([n_prompt], device)
|
823 |
+
else:
|
824 |
+
context = text_encoder([args.prompt], device)
|
825 |
+
context_null = text_encoder([n_prompt], device)
|
826 |
+
|
827 |
+
# free text encoder and clean memory
|
828 |
+
del text_encoder
|
829 |
+
clean_memory_on_device(device)
|
830 |
+
|
831 |
+
# load CLIP model
|
832 |
+
clip = load_clip_model(args, config, device)
|
833 |
+
clip.model.to(device)
|
834 |
+
|
835 |
+
# encode image to CLIP context
|
836 |
+
logger.info(f"Encoding image to CLIP context")
|
837 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
838 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
839 |
+
logger.info(f"Encoding complete")
|
840 |
+
|
841 |
+
# free CLIP model and clean memory
|
842 |
+
del clip
|
843 |
+
clean_memory_on_device(device)
|
844 |
+
else:
|
845 |
+
# Use pre-encoded context
|
846 |
+
context = encoded_context["context"]
|
847 |
+
context_null = encoded_context["context_null"]
|
848 |
+
clip_context = encoded_context["clip_context"]
|
849 |
+
|
850 |
+
# encode image to latent space with VAE
|
851 |
+
logger.info(f"Encoding image to latent space")
|
852 |
+
vae.to_device(device)
|
853 |
+
|
854 |
+
# resize image
|
855 |
+
interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
|
856 |
+
img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
|
857 |
+
img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
|
858 |
+
img_resized = img_resized.unsqueeze(1) # CFHW
|
859 |
+
|
860 |
+
if has_end_image:
|
861 |
+
interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[1] else cv2.INTER_CUBIC
|
862 |
+
end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
|
863 |
+
end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
|
864 |
+
end_img_resized = end_img_resized.unsqueeze(1) # CFHW
|
865 |
+
|
866 |
+
# create mask for the first frame
|
867 |
+
msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
|
868 |
+
msk[:, 0] = 1
|
869 |
+
if has_end_image:
|
870 |
+
msk[:, -1] = 1
|
871 |
+
|
872 |
+
# encode image to latent space
|
873 |
+
with accelerator.autocast(), torch.no_grad():
|
874 |
+
# padding to match the required number of frames
|
875 |
+
padding_frames = frames - 1 # the first frame is image
|
876 |
+
img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
|
877 |
+
y = vae.encode([img_resized])[0]
|
878 |
+
|
879 |
+
if has_end_image:
|
880 |
+
y_end = vae.encode([end_img_resized])[0]
|
881 |
+
y = torch.concat([y, y_end], dim=1) # add end frame
|
882 |
+
|
883 |
+
y = torch.concat([msk, y])
|
884 |
+
logger.info(f"Encoding complete")
|
885 |
+
|
886 |
+
# Fun-Control: encode control video to latent space
|
887 |
+
if config.is_fun_control:
|
888 |
+
# TODO use same resizing as for image
|
889 |
+
logger.info(f"Encoding control video to latent space")
|
890 |
+
# C, F, H, W
|
891 |
+
control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
|
892 |
+
with accelerator.autocast(), torch.no_grad():
|
893 |
+
control_latent = vae.encode([control_video])[0]
|
894 |
+
y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it
|
895 |
+
if has_end_image:
|
896 |
+
y[:, 1:-1] = 0 # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work
|
897 |
+
else:
|
898 |
+
y[:, 1:] = 0 # remove image latent except first frame
|
899 |
+
y = torch.concat([control_latent, y], dim=0) # add control video latent
|
900 |
+
|
901 |
+
# prepare model input arguments
|
902 |
+
arg_c = {
|
903 |
+
"context": [context[0]],
|
904 |
+
"clip_fea": clip_context,
|
905 |
+
"seq_len": max_seq_len,
|
906 |
+
"y": [y],
|
907 |
+
}
|
908 |
+
|
909 |
+
arg_null = {
|
910 |
+
"context": context_null,
|
911 |
+
"clip_fea": clip_context,
|
912 |
+
"seq_len": max_seq_len,
|
913 |
+
"y": [y],
|
914 |
+
}
|
915 |
+
|
916 |
+
vae.to_device("cpu") # move VAE to CPU to save memory
|
917 |
+
clean_memory_on_device(device)
|
918 |
+
|
919 |
+
return noise, context, context_null, y, (arg_c, arg_null)
|
920 |
+
|
921 |
+
|
922 |
+
def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
|
923 |
+
"""load control video to latent space
|
924 |
+
|
925 |
+
Args:
|
926 |
+
control_path: path to control video
|
927 |
+
frames: number of frames in the video
|
928 |
+
height: height of the video
|
929 |
+
width: width of the video
|
930 |
+
|
931 |
+
Returns:
|
932 |
+
torch.Tensor: control video latent, CFHW
|
933 |
+
"""
|
934 |
+
logger.info(f"Load control video from {control_path}")
|
935 |
+
video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames
|
936 |
+
if len(video) < frames:
|
937 |
+
raise ValueError(f"Video length is less than {frames}")
|
938 |
+
# video = np.stack(video, axis=0) # F, H, W, C
|
939 |
+
video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1
|
940 |
+
video = video.permute(1, 0, 2, 3) # C, F, H, W
|
941 |
+
return video
|
942 |
+
|
943 |
+
|
944 |
+
def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
|
945 |
+
"""setup scheduler for sampling
|
946 |
+
|
947 |
+
Args:
|
948 |
+
args: command line arguments
|
949 |
+
config: model configuration
|
950 |
+
device: device to use
|
951 |
+
|
952 |
+
Returns:
|
953 |
+
Tuple[Any, torch.Tensor]: (scheduler, timesteps)
|
954 |
+
"""
|
955 |
+
if args.sample_solver == "unipc":
|
956 |
+
scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
|
957 |
+
scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
|
958 |
+
timesteps = scheduler.timesteps
|
959 |
+
elif args.sample_solver == "dpm++":
|
960 |
+
scheduler = FlowDPMSolverMultistepScheduler(
|
961 |
+
num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
|
962 |
+
)
|
963 |
+
sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
|
964 |
+
timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
|
965 |
+
elif args.sample_solver == "vanilla":
|
966 |
+
scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
|
967 |
+
scheduler.set_timesteps(args.infer_steps, device=device)
|
968 |
+
timesteps = scheduler.timesteps
|
969 |
+
|
970 |
+
# FlowMatchDiscreteScheduler does not support generator argument in step method
|
971 |
+
org_step = scheduler.step
|
972 |
+
|
973 |
+
def step_wrapper(
|
974 |
+
model_output: torch.Tensor,
|
975 |
+
timestep: Union[int, torch.Tensor],
|
976 |
+
sample: torch.Tensor,
|
977 |
+
return_dict: bool = True,
|
978 |
+
generator=None,
|
979 |
+
):
|
980 |
+
return org_step(model_output, timestep, sample, return_dict=return_dict)
|
981 |
+
|
982 |
+
scheduler.step = step_wrapper
|
983 |
+
else:
|
984 |
+
raise NotImplementedError("Unsupported solver.")
|
985 |
+
|
986 |
+
return scheduler, timesteps
|
987 |
+
|
988 |
+
|
989 |
+
def run_sampling(
|
990 |
+
model: WanModel,
|
991 |
+
noise: torch.Tensor,
|
992 |
+
scheduler: Any,
|
993 |
+
timesteps: torch.Tensor,
|
994 |
+
args: argparse.Namespace,
|
995 |
+
inputs: Tuple[dict, dict],
|
996 |
+
device: torch.device,
|
997 |
+
seed_g: torch.Generator,
|
998 |
+
accelerator: Accelerator,
|
999 |
+
is_i2v: bool = False,
|
1000 |
+
use_cpu_offload: bool = True,
|
1001 |
+
) -> torch.Tensor:
|
1002 |
+
"""run sampling
|
1003 |
+
Args:
|
1004 |
+
model: dit model
|
1005 |
+
noise: initial noise
|
1006 |
+
scheduler: scheduler for sampling
|
1007 |
+
timesteps: time steps for sampling
|
1008 |
+
args: command line arguments
|
1009 |
+
inputs: model input (arg_c, arg_null)
|
1010 |
+
device: device to use
|
1011 |
+
seed_g: random generator
|
1012 |
+
accelerator: Accelerator instance
|
1013 |
+
is_i2v: I2V mode (False means T2V mode)
|
1014 |
+
use_cpu_offload: Whether to offload tensors to CPU during processing
|
1015 |
+
Returns:
|
1016 |
+
torch.Tensor: generated latent
|
1017 |
+
"""
|
1018 |
+
arg_c, arg_null = inputs
|
1019 |
+
|
1020 |
+
latent = noise
|
1021 |
+
latent_storage_device = device if not use_cpu_offload else "cpu"
|
1022 |
+
latent = latent.to(latent_storage_device)
|
1023 |
+
|
1024 |
+
# cfg skip
|
1025 |
+
apply_cfg_array = []
|
1026 |
+
num_timesteps = len(timesteps)
|
1027 |
+
|
1028 |
+
if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
|
1029 |
+
# Calculate thresholds based on cfg_apply_ratio
|
1030 |
+
apply_steps = int(num_timesteps * args.cfg_apply_ratio)
|
1031 |
+
|
1032 |
+
if args.cfg_skip_mode == "early":
|
1033 |
+
# Skip CFG in early steps, apply in late steps
|
1034 |
+
start_index = num_timesteps - apply_steps
|
1035 |
+
end_index = num_timesteps
|
1036 |
+
elif args.cfg_skip_mode == "late":
|
1037 |
+
# Skip CFG in late steps, apply in early steps
|
1038 |
+
start_index = 0
|
1039 |
+
end_index = apply_steps
|
1040 |
+
elif args.cfg_skip_mode == "early_late":
|
1041 |
+
# Skip CFG in early and late steps, apply in middle steps
|
1042 |
+
start_index = (num_timesteps - apply_steps) // 2
|
1043 |
+
end_index = start_index + apply_steps
|
1044 |
+
elif args.cfg_skip_mode == "middle":
|
1045 |
+
# Skip CFG in middle steps, apply in early and late steps
|
1046 |
+
skip_steps = num_timesteps - apply_steps
|
1047 |
+
middle_start = (num_timesteps - skip_steps) // 2
|
1048 |
+
middle_end = middle_start + skip_steps
|
1049 |
+
|
1050 |
+
w = 0.0
|
1051 |
+
for step_idx in range(num_timesteps):
|
1052 |
+
if args.cfg_skip_mode == "alternate":
|
1053 |
+
# accumulate w and apply CFG when w >= 1.0
|
1054 |
+
w += args.cfg_apply_ratio
|
1055 |
+
apply = w >= 1.0
|
1056 |
+
if apply:
|
1057 |
+
w -= 1.0
|
1058 |
+
elif args.cfg_skip_mode == "middle":
|
1059 |
+
# Skip CFG in early and late steps, apply in middle steps
|
1060 |
+
apply = step_idx < middle_start or step_idx >= middle_end
|
1061 |
+
else:
|
1062 |
+
# Apply CFG on some steps based on ratio
|
1063 |
+
apply = step_idx >= start_index and step_idx < end_index
|
1064 |
+
|
1065 |
+
apply_cfg_array.append(apply)
|
1066 |
+
|
1067 |
+
pattern = ["A" if apply else "S" for apply in apply_cfg_array]
|
1068 |
+
pattern = "".join(pattern)
|
1069 |
+
logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
|
1070 |
+
else:
|
1071 |
+
# Apply CFG on all steps
|
1072 |
+
apply_cfg_array = [True] * num_timesteps
|
1073 |
+
|
1074 |
+
# SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
|
1075 |
+
slg_start_step = int(args.slg_start * num_timesteps)
|
1076 |
+
slg_end_step = int(args.slg_end * num_timesteps)
|
1077 |
+
|
1078 |
+
for i, t in enumerate(tqdm(timesteps)):
|
1079 |
+
# latent is on CPU if use_cpu_offload is True
|
1080 |
+
latent_model_input = [latent.to(device)]
|
1081 |
+
timestep = torch.stack([t]).to(device)
|
1082 |
+
|
1083 |
+
with accelerator.autocast(), torch.no_grad():
|
1084 |
+
noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)
|
1085 |
+
|
1086 |
+
apply_cfg = apply_cfg_array[i] # apply CFG or not
|
1087 |
+
if apply_cfg:
|
1088 |
+
apply_slg = i >= slg_start_step and i < slg_end_step
|
1089 |
+
# print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
|
1090 |
+
if args.slg_mode == "original" and apply_slg:
|
1091 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
|
1092 |
+
|
1093 |
+
# apply guidance
|
1094 |
+
# SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
|
1095 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1096 |
+
|
1097 |
+
# calculate skip layer out
|
1098 |
+
skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
|
1099 |
+
latent_storage_device
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
# apply skip layer guidance
|
1103 |
+
# SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
|
1104 |
+
noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
|
1105 |
+
elif args.slg_mode == "uncond" and apply_slg:
|
1106 |
+
# noise_pred_uncond is skip layer out
|
1107 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
|
1108 |
+
latent_storage_device
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
# apply guidance
|
1112 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1113 |
+
|
1114 |
+
else:
|
1115 |
+
# normal guidance
|
1116 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
|
1117 |
+
|
1118 |
+
# apply guidance
|
1119 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1120 |
+
else:
|
1121 |
+
noise_pred = noise_pred_cond
|
1122 |
+
|
1123 |
+
# step
|
1124 |
+
latent_input = latent.unsqueeze(0)
|
1125 |
+
temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]
|
1126 |
+
|
1127 |
+
# update latent
|
1128 |
+
latent = temp_x0.squeeze(0)
|
1129 |
+
|
1130 |
+
return latent
|
1131 |
+
|
1132 |
+
|
1133 |
+
def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
|
1134 |
+
"""main function for generation
|
1135 |
+
|
1136 |
+
Args:
|
1137 |
+
args: command line arguments
|
1138 |
+
shared_models: dictionary containing pre-loaded models and encoded data
|
1139 |
+
|
1140 |
+
Returns:
|
1141 |
+
torch.Tensor: generated latent
|
1142 |
+
"""
|
1143 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1144 |
+
gen_settings.device,
|
1145 |
+
gen_settings.cfg,
|
1146 |
+
gen_settings.dit_dtype,
|
1147 |
+
gen_settings.dit_weight_dtype,
|
1148 |
+
gen_settings.vae_dtype,
|
1149 |
+
)
|
1150 |
+
|
1151 |
+
# prepare accelerator
|
1152 |
+
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
|
1153 |
+
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
|
1154 |
+
|
1155 |
+
# I2V or T2V
|
1156 |
+
is_i2v = "i2v" in args.task
|
1157 |
+
|
1158 |
+
# prepare seed
|
1159 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
1160 |
+
args.seed = seed # set seed to args for saving
|
1161 |
+
|
1162 |
+
# Check if we have shared models
|
1163 |
+
if shared_models is not None:
|
1164 |
+
# Use shared models and encoded data
|
1165 |
+
vae = shared_models.get("vae")
|
1166 |
+
model = shared_models.get("model")
|
1167 |
+
encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
|
1168 |
+
|
1169 |
+
# prepare inputs
|
1170 |
+
if is_i2v:
|
1171 |
+
# I2V
|
1172 |
+
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
|
1173 |
+
else:
|
1174 |
+
# T2V
|
1175 |
+
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
|
1176 |
+
else:
|
1177 |
+
# prepare inputs without shared models
|
1178 |
+
if is_i2v:
|
1179 |
+
# I2V: need text encoder, VAE and CLIP
|
1180 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1181 |
+
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
|
1182 |
+
# vae is on CPU after prepare_i2v_inputs
|
1183 |
+
else:
|
1184 |
+
# T2V: need text encoder
|
1185 |
+
vae = None
|
1186 |
+
if cfg.is_fun_control:
|
1187 |
+
# Fun-Control: need VAE for encoding control video
|
1188 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1189 |
+
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)
|
1190 |
+
|
1191 |
+
# load DiT model
|
1192 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1193 |
+
|
1194 |
+
# merge LoRA weights
|
1195 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1196 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1197 |
+
|
1198 |
+
# if we only want to save the model, we can skip the rest
|
1199 |
+
if args.save_merged_model:
|
1200 |
+
return None
|
1201 |
+
|
1202 |
+
# optimize model: fp8 conversion, block swap etc.
|
1203 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1204 |
+
|
1205 |
+
# setup scheduler
|
1206 |
+
scheduler, timesteps = setup_scheduler(args, cfg, device)
|
1207 |
+
|
1208 |
+
# set random generator
|
1209 |
+
seed_g = torch.Generator(device=device)
|
1210 |
+
seed_g.manual_seed(seed)
|
1211 |
+
|
1212 |
+
# run sampling
|
1213 |
+
latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)
|
1214 |
+
|
1215 |
+
# Only clean up shared models if they were created within this function
|
1216 |
+
if shared_models is None:
|
1217 |
+
# free memory
|
1218 |
+
del model
|
1219 |
+
del scheduler
|
1220 |
+
synchronize_device(device)
|
1221 |
+
|
1222 |
+
# wait for 5 seconds until block swap is done
|
1223 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
1224 |
+
time.sleep(5)
|
1225 |
+
|
1226 |
+
gc.collect()
|
1227 |
+
clean_memory_on_device(device)
|
1228 |
+
|
1229 |
+
# save VAE model for decoding
|
1230 |
+
if vae is None:
|
1231 |
+
args._vae = None
|
1232 |
+
else:
|
1233 |
+
args._vae = vae
|
1234 |
+
|
1235 |
+
return latent
|
1236 |
+
|
1237 |
+
|
1238 |
+
def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
|
1239 |
+
"""decode latent
|
1240 |
+
|
1241 |
+
Args:
|
1242 |
+
latent: latent tensor
|
1243 |
+
args: command line arguments
|
1244 |
+
cfg: model configuration
|
1245 |
+
|
1246 |
+
Returns:
|
1247 |
+
torch.Tensor: decoded video or image
|
1248 |
+
"""
|
1249 |
+
device = torch.device(args.device)
|
1250 |
+
|
1251 |
+
# load VAE model or use the one from the generation
|
1252 |
+
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
|
1253 |
+
if hasattr(args, "_vae") and args._vae is not None:
|
1254 |
+
vae = args._vae
|
1255 |
+
else:
|
1256 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1257 |
+
|
1258 |
+
vae.to_device(device)
|
1259 |
+
|
1260 |
+
logger.info(f"Decoding video from latents: {latent.shape}")
|
1261 |
+
x0 = latent.to(device)
|
1262 |
+
|
1263 |
+
with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
|
1264 |
+
videos = vae.decode(x0)
|
1265 |
+
|
1266 |
+
# some tail frames may be corrupted when end frame is used, we add an option to remove them
|
1267 |
+
if args.trim_tail_frames:
|
1268 |
+
videos[0] = videos[0][:, : -args.trim_tail_frames]
|
1269 |
+
|
1270 |
+
logger.info(f"Decoding complete")
|
1271 |
+
video = videos[0]
|
1272 |
+
del videos
|
1273 |
+
video = video.to(torch.float32).cpu()
|
1274 |
+
|
1275 |
+
return video
|
1276 |
+
|
1277 |
+
|
1278 |
+
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
|
1279 |
+
"""Save latent to file
|
1280 |
+
|
1281 |
+
Args:
|
1282 |
+
latent: latent tensor
|
1283 |
+
args: command line arguments
|
1284 |
+
height: height of frame
|
1285 |
+
width: width of frame
|
1286 |
+
|
1287 |
+
Returns:
|
1288 |
+
str: Path to saved latent file
|
1289 |
+
"""
|
1290 |
+
save_path = args.save_path
|
1291 |
+
os.makedirs(save_path, exist_ok=True)
|
1292 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1293 |
+
|
1294 |
+
seed = args.seed
|
1295 |
+
video_length = args.video_length
|
1296 |
+
latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
|
1297 |
+
|
1298 |
+
if args.no_metadata:
|
1299 |
+
metadata = None
|
1300 |
+
else:
|
1301 |
+
metadata = {
|
1302 |
+
"seeds": f"{seed}",
|
1303 |
+
"prompt": f"{args.prompt}",
|
1304 |
+
"height": f"{height}",
|
1305 |
+
"width": f"{width}",
|
1306 |
+
"video_length": f"{video_length}",
|
1307 |
+
"infer_steps": f"{args.infer_steps}",
|
1308 |
+
"guidance_scale": f"{args.guidance_scale}",
|
1309 |
+
}
|
1310 |
+
if args.negative_prompt is not None:
|
1311 |
+
metadata["negative_prompt"] = f"{args.negative_prompt}"
|
1312 |
+
|
1313 |
+
sd = {"latent": latent}
|
1314 |
+
save_file(sd, latent_path, metadata=metadata)
|
1315 |
+
logger.info(f"Latent saved to: {latent_path}")
|
1316 |
+
|
1317 |
+
return latent_path
|
1318 |
+
|
1319 |
+
|
1320 |
+
def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
|
1321 |
+
"""Save video to file
|
1322 |
+
|
1323 |
+
Args:
|
1324 |
+
video: Video tensor
|
1325 |
+
args: command line arguments
|
1326 |
+
original_base_name: Original base name (if latents are loaded from files)
|
1327 |
+
|
1328 |
+
Returns:
|
1329 |
+
str: Path to saved video file
|
1330 |
+
"""
|
1331 |
+
save_path = args.save_path
|
1332 |
+
os.makedirs(save_path, exist_ok=True)
|
1333 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1334 |
+
|
1335 |
+
seed = args.seed
|
1336 |
+
original_name = "" if original_base_name is None else f"_{original_base_name}"
|
1337 |
+
video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"
|
1338 |
+
|
1339 |
+
video = video.unsqueeze(0)
|
1340 |
+
save_videos_grid(video, video_path, fps=args.fps, rescale=True)
|
1341 |
+
logger.info(f"Video saved to: {video_path}")
|
1342 |
+
|
1343 |
+
return video_path
|
1344 |
+
|
1345 |
+
|
1346 |
+
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
|
1347 |
+
"""Save images to directory
|
1348 |
+
|
1349 |
+
Args:
|
1350 |
+
sample: Video tensor
|
1351 |
+
args: command line arguments
|
1352 |
+
original_base_name: Original base name (if latents are loaded from files)
|
1353 |
+
|
1354 |
+
Returns:
|
1355 |
+
str: Path to saved images directory
|
1356 |
+
"""
|
1357 |
+
save_path = args.save_path
|
1358 |
+
os.makedirs(save_path, exist_ok=True)
|
1359 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1360 |
+
|
1361 |
+
seed = args.seed
|
1362 |
+
original_name = "" if original_base_name is None else f"_{original_base_name}"
|
1363 |
+
image_name = f"{time_flag}_{seed}{original_name}"
|
1364 |
+
sample = sample.unsqueeze(0)
|
1365 |
+
save_images_grid(sample, save_path, image_name, rescale=True)
|
1366 |
+
logger.info(f"Sample images saved to: {save_path}/{image_name}")
|
1367 |
+
|
1368 |
+
return f"{save_path}/{image_name}"
|
1369 |
+
|
1370 |
+
|
1371 |
+
def save_output(
|
1372 |
+
latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
|
1373 |
+
) -> None:
|
1374 |
+
"""save output
|
1375 |
+
|
1376 |
+
Args:
|
1377 |
+
latent: latent tensor
|
1378 |
+
args: command line arguments
|
1379 |
+
cfg: model configuration
|
1380 |
+
height: height of frame
|
1381 |
+
width: width of frame
|
1382 |
+
original_base_names: original base names (if latents are loaded from files)
|
1383 |
+
"""
|
1384 |
+
if args.output_type == "latent" or args.output_type == "both":
|
1385 |
+
# save latent
|
1386 |
+
save_latent(latent, args, height, width)
|
1387 |
+
|
1388 |
+
if args.output_type == "video" or args.output_type == "both":
|
1389 |
+
# save video
|
1390 |
+
sample = decode_latent(latent.unsqueeze(0), args, cfg)
|
1391 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
|
1392 |
+
save_video(sample, args, original_name)
|
1393 |
+
|
1394 |
+
elif args.output_type == "images":
|
1395 |
+
# save images
|
1396 |
+
sample = decode_latent(latent.unsqueeze(0), args, cfg)
|
1397 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
|
1398 |
+
save_images(sample, args, original_name)
|
1399 |
+
|
1400 |
+
|
1401 |
+
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
|
1402 |
+
"""Process multiple prompts for batch mode
|
1403 |
+
|
1404 |
+
Args:
|
1405 |
+
prompt_lines: List of prompt lines
|
1406 |
+
base_args: Base command line arguments
|
1407 |
+
|
1408 |
+
Returns:
|
1409 |
+
List[Dict]: List of prompt data dictionaries
|
1410 |
+
"""
|
1411 |
+
prompts_data = []
|
1412 |
+
|
1413 |
+
for line in prompt_lines:
|
1414 |
+
line = line.strip()
|
1415 |
+
if not line or line.startswith("#"): # Skip empty lines and comments
|
1416 |
+
continue
|
1417 |
+
|
1418 |
+
# Parse prompt line and create override dictionary
|
1419 |
+
prompt_data = parse_prompt_line(line)
|
1420 |
+
logger.info(f"Parsed prompt data: {prompt_data}")
|
1421 |
+
prompts_data.append(prompt_data)
|
1422 |
+
|
1423 |
+
return prompts_data
|
1424 |
+
|
1425 |
+
|
1426 |
+
def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
|
1427 |
+
"""Process multiple prompts with model reuse
|
1428 |
+
|
1429 |
+
Args:
|
1430 |
+
prompts_data: List of prompt data dictionaries
|
1431 |
+
args: Base command line arguments
|
1432 |
+
"""
|
1433 |
+
if not prompts_data:
|
1434 |
+
logger.warning("No valid prompts found")
|
1435 |
+
return
|
1436 |
+
|
1437 |
+
# 1. Load configuration
|
1438 |
+
gen_settings = get_generation_settings(args)
|
1439 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1440 |
+
gen_settings.device,
|
1441 |
+
gen_settings.cfg,
|
1442 |
+
gen_settings.dit_dtype,
|
1443 |
+
gen_settings.dit_weight_dtype,
|
1444 |
+
gen_settings.vae_dtype,
|
1445 |
+
)
|
1446 |
+
is_i2v = "i2v" in args.task
|
1447 |
+
|
1448 |
+
# 2. Encode all prompts
|
1449 |
+
logger.info("Loading text encoder to encode all prompts")
|
1450 |
+
text_encoder = load_text_encoder(args, cfg, device)
|
1451 |
+
text_encoder.model.to(device)
|
1452 |
+
|
1453 |
+
encoded_contexts = {}
|
1454 |
+
|
1455 |
+
with torch.no_grad():
|
1456 |
+
for prompt_data in prompts_data:
|
1457 |
+
prompt = prompt_data["prompt"]
|
1458 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1459 |
+
n_prompt = prompt_data.get(
|
1460 |
+
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
|
1461 |
+
)
|
1462 |
+
|
1463 |
+
if args.fp8_t5:
|
1464 |
+
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
|
1465 |
+
context = text_encoder([prompt], device)
|
1466 |
+
context_null = text_encoder([n_prompt], device)
|
1467 |
+
else:
|
1468 |
+
context = text_encoder([prompt], device)
|
1469 |
+
context_null = text_encoder([n_prompt], device)
|
1470 |
+
|
1471 |
+
encoded_contexts[prompt] = {"context": context, "context_null": context_null}
|
1472 |
+
|
1473 |
+
# Free text encoder and clean memory
|
1474 |
+
del text_encoder
|
1475 |
+
clean_memory_on_device(device)
|
1476 |
+
|
1477 |
+
# 3. Process I2V additional encodings if needed
|
1478 |
+
vae = None
|
1479 |
+
if is_i2v:
|
1480 |
+
logger.info("Loading VAE and CLIP for I2V preprocessing")
|
1481 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1482 |
+
vae.to_device(device)
|
1483 |
+
|
1484 |
+
clip = load_clip_model(args, cfg, device)
|
1485 |
+
clip.model.to(device)
|
1486 |
+
|
1487 |
+
# Process each image and encode with CLIP
|
1488 |
+
for prompt_data in prompts_data:
|
1489 |
+
if "image_path" not in prompt_data:
|
1490 |
+
continue
|
1491 |
+
|
1492 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1493 |
+
if not os.path.exists(prompt_args.image_path):
|
1494 |
+
logger.warning(f"Image path not found: {prompt_args.image_path}")
|
1495 |
+
continue
|
1496 |
+
|
1497 |
+
# Load and encode image with CLIP
|
1498 |
+
img = Image.open(prompt_args.image_path).convert("RGB")
|
1499 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
1500 |
+
|
1501 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
1502 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
1503 |
+
|
1504 |
+
encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context
|
1505 |
+
|
1506 |
+
# Free CLIP and clean memory
|
1507 |
+
del clip
|
1508 |
+
clean_memory_on_device(device)
|
1509 |
+
|
1510 |
+
# Keep VAE in CPU memory for later use
|
1511 |
+
vae.to_device("cpu")
|
1512 |
+
elif cfg.is_fun_control:
|
1513 |
+
# For Fun-Control, we need VAE but keep it on CPU
|
1514 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1515 |
+
vae.to_device("cpu")
|
1516 |
+
|
1517 |
+
# 4. Load DiT model
|
1518 |
+
logger.info("Loading DiT model")
|
1519 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1520 |
+
|
1521 |
+
# 5. Merge LoRA weights if needed
|
1522 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1523 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1524 |
+
if args.save_merged_model:
|
1525 |
+
logger.info("Model merged and saved. Exiting.")
|
1526 |
+
return
|
1527 |
+
|
1528 |
+
# 6. Optimize model
|
1529 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1530 |
+
|
1531 |
+
# Create shared models dict for generate function
|
1532 |
+
shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}
|
1533 |
+
|
1534 |
+
# 7. Generate for each prompt
|
1535 |
+
all_latents = []
|
1536 |
+
all_prompt_args = []
|
1537 |
+
|
1538 |
+
for i, prompt_data in enumerate(prompts_data):
|
1539 |
+
logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")
|
1540 |
+
|
1541 |
+
# Apply overrides for this prompt
|
1542 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1543 |
+
|
1544 |
+
# Generate latent
|
1545 |
+
latent = generate(prompt_args, gen_settings, shared_models)
|
1546 |
+
|
1547 |
+
# Save latent if needed
|
1548 |
+
height, width, _ = check_inputs(prompt_args)
|
1549 |
+
if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
|
1550 |
+
save_latent(latent, prompt_args, height, width)
|
1551 |
+
|
1552 |
+
all_latents.append(latent)
|
1553 |
+
all_prompt_args.append(prompt_args)
|
1554 |
+
|
1555 |
+
# 8. Free DiT model
|
1556 |
+
del model
|
1557 |
+
clean_memory_on_device(device)
|
1558 |
+
synchronize_device(device)
|
1559 |
+
|
1560 |
+
# wait for 5 seconds until block swap is done
|
1561 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
1562 |
+
time.sleep(5)
|
1563 |
+
|
1564 |
+
gc.collect()
|
1565 |
+
clean_memory_on_device(device)
|
1566 |
+
|
1567 |
+
# 9. Decode latents if needed
|
1568 |
+
if args.output_type != "latent":
|
1569 |
+
logger.info("Decoding latents to videos/images")
|
1570 |
+
|
1571 |
+
if vae is None:
|
1572 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1573 |
+
|
1574 |
+
vae.to_device(device)
|
1575 |
+
|
1576 |
+
for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
|
1577 |
+
logger.info(f"Decoding output {i+1}/{len(all_latents)}")
|
1578 |
+
|
1579 |
+
# Decode latent
|
1580 |
+
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
|
1581 |
+
|
1582 |
+
# Save as video or images
|
1583 |
+
if prompt_args.output_type == "video" or prompt_args.output_type == "both":
|
1584 |
+
save_video(video, prompt_args)
|
1585 |
+
elif prompt_args.output_type == "images":
|
1586 |
+
save_images(video, prompt_args)
|
1587 |
+
|
1588 |
+
# Free VAE
|
1589 |
+
del vae
|
1590 |
+
|
1591 |
+
clean_memory_on_device(device)
|
1592 |
+
gc.collect()
|
1593 |
+
|
1594 |
+
|
1595 |
+
def process_interactive(args: argparse.Namespace) -> None:
|
1596 |
+
"""Process prompts in interactive mode
|
1597 |
+
|
1598 |
+
Args:
|
1599 |
+
args: Base command line arguments
|
1600 |
+
"""
|
1601 |
+
gen_settings = get_generation_settings(args)
|
1602 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1603 |
+
gen_settings.device,
|
1604 |
+
gen_settings.cfg,
|
1605 |
+
gen_settings.dit_dtype,
|
1606 |
+
gen_settings.dit_weight_dtype,
|
1607 |
+
gen_settings.vae_dtype,
|
1608 |
+
)
|
1609 |
+
is_i2v = "i2v" in args.task
|
1610 |
+
|
1611 |
+
# Initialize models to None
|
1612 |
+
text_encoder = None
|
1613 |
+
vae = None
|
1614 |
+
model = None
|
1615 |
+
clip = None
|
1616 |
+
|
1617 |
+
print("Interactive mode. Enter prompts (Ctrl+D to exit):")
|
1618 |
+
|
1619 |
+
try:
|
1620 |
+
while True:
|
1621 |
+
try:
|
1622 |
+
line = input("> ")
|
1623 |
+
if not line.strip():
|
1624 |
+
continue
|
1625 |
+
|
1626 |
+
# Parse prompt
|
1627 |
+
prompt_data = parse_prompt_line(line)
|
1628 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1629 |
+
|
1630 |
+
# Ensure we have all the models we need
|
1631 |
+
|
1632 |
+
# 1. Load text encoder if not already loaded
|
1633 |
+
if text_encoder is None:
|
1634 |
+
logger.info("Loading text encoder")
|
1635 |
+
text_encoder = load_text_encoder(args, cfg, device)
|
1636 |
+
|
1637 |
+
text_encoder.model.to(device)
|
1638 |
+
|
1639 |
+
# Encode prompt
|
1640 |
+
n_prompt = prompt_data.get(
|
1641 |
+
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
|
1642 |
+
)
|
1643 |
+
|
1644 |
+
with torch.no_grad():
|
1645 |
+
if args.fp8_t5:
|
1646 |
+
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
|
1647 |
+
context = text_encoder([prompt_data["prompt"]], device)
|
1648 |
+
context_null = text_encoder([n_prompt], device)
|
1649 |
+
else:
|
1650 |
+
context = text_encoder([prompt_data["prompt"]], device)
|
1651 |
+
context_null = text_encoder([n_prompt], device)
|
1652 |
+
|
1653 |
+
encoded_context = {"context": context, "context_null": context_null}
|
1654 |
+
|
1655 |
+
# Move text encoder to CPU after use
|
1656 |
+
text_encoder.model.to("cpu")
|
1657 |
+
|
1658 |
+
# 2. For I2V, we need CLIP and VAE
|
1659 |
+
if is_i2v:
|
1660 |
+
if clip is None:
|
1661 |
+
logger.info("Loading CLIP model")
|
1662 |
+
clip = load_clip_model(args, cfg, device)
|
1663 |
+
|
1664 |
+
clip.model.to(device)
|
1665 |
+
|
1666 |
+
# Encode image with CLIP if there's an image path
|
1667 |
+
if prompt_args.image_path and os.path.exists(prompt_args.image_path):
|
1668 |
+
img = Image.open(prompt_args.image_path).convert("RGB")
|
1669 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
1670 |
+
|
1671 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
1672 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
1673 |
+
|
1674 |
+
encoded_context["clip_context"] = clip_context
|
1675 |
+
|
1676 |
+
# Move CLIP to CPU after use
|
1677 |
+
clip.model.to("cpu")
|
1678 |
+
|
1679 |
+
# Load VAE if needed
|
1680 |
+
if vae is None:
|
1681 |
+
logger.info("Loading VAE model")
|
1682 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1683 |
+
elif cfg.is_fun_control and vae is None:
|
1684 |
+
# For Fun-Control, we need VAE
|
1685 |
+
logger.info("Loading VAE model for Fun-Control")
|
1686 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1687 |
+
|
1688 |
+
# 3. Load DiT model if not already loaded
|
1689 |
+
if model is None:
|
1690 |
+
logger.info("Loading DiT model")
|
1691 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1692 |
+
|
1693 |
+
# Merge LoRA weights if needed
|
1694 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1695 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1696 |
+
|
1697 |
+
# Optimize model
|
1698 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1699 |
+
else:
|
1700 |
+
# Move model to GPU if it was offloaded
|
1701 |
+
model.to(device)
|
1702 |
+
|
1703 |
+
# Create shared models dict
|
1704 |
+
shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}
|
1705 |
+
|
1706 |
+
# Generate latent
|
1707 |
+
latent = generate(prompt_args, gen_settings, shared_models)
|
1708 |
+
|
1709 |
+
# Move model to CPU after generation
|
1710 |
+
model.to("cpu")
|
1711 |
+
|
1712 |
+
# Save latent if needed
|
1713 |
+
height, width, _ = check_inputs(prompt_args)
|
1714 |
+
if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
|
1715 |
+
save_latent(latent, prompt_args, height, width)
|
1716 |
+
|
1717 |
+
# Decode and save output
|
1718 |
+
if prompt_args.output_type != "latent":
|
1719 |
+
if vae is None:
|
1720 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1721 |
+
|
1722 |
+
vae.to_device(device)
|
1723 |
+
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
|
1724 |
+
|
1725 |
+
if prompt_args.output_type == "video" or prompt_args.output_type == "both":
|
1726 |
+
save_video(video, prompt_args)
|
1727 |
+
elif prompt_args.output_type == "images":
|
1728 |
+
save_images(video, prompt_args)
|
1729 |
+
|
1730 |
+
# Move VAE to CPU after use
|
1731 |
+
vae.to_device("cpu")
|
1732 |
+
|
1733 |
+
clean_memory_on_device(device)
|
1734 |
+
|
1735 |
+
except KeyboardInterrupt:
|
1736 |
+
print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
|
1737 |
+
continue
|
1738 |
+
|
1739 |
+
except EOFError:
|
1740 |
+
print("\nExiting interactive mode")
|
1741 |
+
|
1742 |
+
# Clean up all models
|
1743 |
+
if text_encoder is not None:
|
1744 |
+
del text_encoder
|
1745 |
+
if clip is not None:
|
1746 |
+
del clip
|
1747 |
+
if vae is not None:
|
1748 |
+
del vae
|
1749 |
+
if model is not None:
|
1750 |
+
del model
|
1751 |
+
|
1752 |
+
clean_memory_on_device(device)
|
1753 |
+
gc.collect()
|
1754 |
+
|
1755 |
+
|
1756 |
+
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
|
1757 |
+
device = torch.device(args.device)
|
1758 |
+
|
1759 |
+
cfg = WAN_CONFIGS[args.task]
|
1760 |
+
|
1761 |
+
# select dtype
|
1762 |
+
dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
|
1763 |
+
if dit_dtype.itemsize == 1:
|
1764 |
+
# if weight is in fp8, use bfloat16 for DiT (input/output)
|
1765 |
+
dit_dtype = torch.bfloat16
|
1766 |
+
if args.fp8_scaled:
|
1767 |
+
raise ValueError(
|
1768 |
+
"DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
|
1769 |
+
)
|
1770 |
+
|
1771 |
+
dit_weight_dtype = dit_dtype # default
|
1772 |
+
if args.fp8_scaled:
|
1773 |
+
dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
|
1774 |
+
elif args.fp8:
|
1775 |
+
dit_weight_dtype = torch.float8_e4m3fn
|
1776 |
+
|
1777 |
+
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
|
1778 |
+
logger.info(
|
1779 |
+
f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
|
1780 |
+
)
|
1781 |
+
|
1782 |
+
gen_settings = GenerationSettings(
|
1783 |
+
device=device,
|
1784 |
+
cfg=cfg,
|
1785 |
+
dit_dtype=dit_dtype,
|
1786 |
+
dit_weight_dtype=dit_weight_dtype,
|
1787 |
+
vae_dtype=vae_dtype,
|
1788 |
+
)
|
1789 |
+
return gen_settings
|
1790 |
+
|
1791 |
+
|
1792 |
+
def main():
|
1793 |
+
# Parse arguments
|
1794 |
+
args = parse_args()
|
1795 |
+
|
1796 |
+
# Check if latents are provided
|
1797 |
+
latents_mode = args.latent_path is not None and len(args.latent_path) > 0
|
1798 |
+
|
1799 |
+
# Set device
|
1800 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
1801 |
+
device = torch.device(device)
|
1802 |
+
logger.info(f"Using device: {device}")
|
1803 |
+
args.device = device
|
1804 |
+
|
1805 |
+
if latents_mode:
|
1806 |
+
# Original latent decode mode
|
1807 |
+
cfg = WAN_CONFIGS[args.task] # any task is fine
|
1808 |
+
original_base_names = []
|
1809 |
+
latents_list = []
|
1810 |
+
seeds = []
|
1811 |
+
|
1812 |
+
assert len(args.latent_path) == 1, "Only one latent path is supported for now"
|
1813 |
+
|
1814 |
+
for latent_path in args.latent_path:
|
1815 |
+
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
|
1816 |
+
seed = 0
|
1817 |
+
|
1818 |
+
if os.path.splitext(latent_path)[1] != ".safetensors":
|
1819 |
+
latents = torch.load(latent_path, map_location="cpu")
|
1820 |
+
else:
|
1821 |
+
latents = load_file(latent_path)["latent"]
|
1822 |
+
with safe_open(latent_path, framework="pt") as f:
|
1823 |
+
metadata = f.metadata()
|
1824 |
+
if metadata is None:
|
1825 |
+
metadata = {}
|
1826 |
+
logger.info(f"Loaded metadata: {metadata}")
|
1827 |
+
|
1828 |
+
if "seeds" in metadata:
|
1829 |
+
seed = int(metadata["seeds"])
|
1830 |
+
if "height" in metadata and "width" in metadata:
|
1831 |
+
height = int(metadata["height"])
|
1832 |
+
width = int(metadata["width"])
|
1833 |
+
args.video_size = [height, width]
|
1834 |
+
if "video_length" in metadata:
|
1835 |
+
args.video_length = int(metadata["video_length"])
|
1836 |
+
|
1837 |
+
seeds.append(seed)
|
1838 |
+
latents_list.append(latents)
|
1839 |
+
|
1840 |
+
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
|
1841 |
+
|
1842 |
+
latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
|
1843 |
+
|
1844 |
+
height = latents.shape[-2]
|
1845 |
+
width = latents.shape[-1]
|
1846 |
+
height *= cfg.patch_size[1] * cfg.vae_stride[1]
|
1847 |
+
width *= cfg.patch_size[2] * cfg.vae_stride[2]
|
1848 |
+
video_length = latents.shape[1]
|
1849 |
+
video_length = (video_length - 1) * cfg.vae_stride[0] + 1
|
1850 |
+
args.seed = seeds[0]
|
1851 |
+
|
1852 |
+
# Decode and save
|
1853 |
+
save_output(latent[0], args, cfg, height, width, original_base_names)
|
1854 |
+
|
1855 |
+
elif args.from_file:
|
1856 |
+
# Batch mode from file
|
1857 |
+
args = setup_args(args)
|
1858 |
+
|
1859 |
+
# Read prompts from file
|
1860 |
+
with open(args.from_file, "r", encoding="utf-8") as f:
|
1861 |
+
prompt_lines = f.readlines()
|
1862 |
+
|
1863 |
+
# Process prompts
|
1864 |
+
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
|
1865 |
+
process_batch_prompts(prompts_data, args)
|
1866 |
+
|
1867 |
+
elif args.interactive:
|
1868 |
+
# Interactive mode
|
1869 |
+
args = setup_args(args)
|
1870 |
+
process_interactive(args)
|
1871 |
+
|
1872 |
+
else:
|
1873 |
+
# Single prompt mode (original behavior)
|
1874 |
+
args = setup_args(args)
|
1875 |
+
height, width, video_length = check_inputs(args)
|
1876 |
+
|
1877 |
+
logger.info(
|
1878 |
+
f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
|
1879 |
+
f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
|
1880 |
+
)
|
1881 |
+
|
1882 |
+
# Generate latent
|
1883 |
+
gen_settings = get_generation_settings(args)
|
1884 |
+
latent = generate(args, gen_settings)
|
1885 |
+
|
1886 |
+
# Make sure the model is freed from GPU memory
|
1887 |
+
gc.collect()
|
1888 |
+
clean_memory_on_device(args.device)
|
1889 |
+
|
1890 |
+
# Save latent and video
|
1891 |
+
if args.save_merged_model:
|
1892 |
+
return
|
1893 |
+
|
1894 |
+
# Add batch dimension
|
1895 |
+
latent = latent.unsqueeze(0)
|
1896 |
+
save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)
|
1897 |
+
|
1898 |
+
logger.info("Done!")
|
1899 |
+
|
1900 |
+
|
1901 |
+
if __name__ == "__main__":
|
1902 |
+
main()
|
wan_train_network.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from typing import Optional
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms.functional as TF
|
9 |
+
from tqdm import tqdm
|
10 |
+
from accelerate import Accelerator, init_empty_weights
|
11 |
+
|
12 |
+
from dataset.image_video_dataset import ARCHITECTURE_WAN, ARCHITECTURE_WAN_FULL, load_video
|
13 |
+
from hv_generate_video import resize_image_to_bucket
|
14 |
+
from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
|
15 |
+
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
logging.basicConfig(level=logging.INFO)
|
20 |
+
|
21 |
+
from utils import model_utils
|
22 |
+
from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
|
23 |
+
from wan.configs import WAN_CONFIGS
|
24 |
+
from wan.modules.clip import CLIPModel
|
25 |
+
from wan.modules.model import WanModel, detect_wan_sd_dtype, load_wan_model
|
26 |
+
from wan.modules.t5 import T5EncoderModel
|
27 |
+
from wan.modules.vae import WanVAE
|
28 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
29 |
+
|
30 |
+
|
31 |
+
class WanNetworkTrainer(NetworkTrainer):
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
# region model specific
|
36 |
+
|
37 |
+
@property
|
38 |
+
def architecture(self) -> str:
|
39 |
+
return ARCHITECTURE_WAN
|
40 |
+
|
41 |
+
@property
|
42 |
+
def architecture_full_name(self) -> str:
|
43 |
+
return ARCHITECTURE_WAN_FULL
|
44 |
+
|
45 |
+
def handle_model_specific_args(self, args):
|
46 |
+
self.config = WAN_CONFIGS[args.task]
|
47 |
+
self._i2v_training = "i2v" in args.task # we cannot use config.i2v because Fun-Control T2V has i2v flag TODO refactor this
|
48 |
+
self._control_training = self.config.is_fun_control
|
49 |
+
|
50 |
+
self.dit_dtype = detect_wan_sd_dtype(args.dit)
|
51 |
+
|
52 |
+
if self.dit_dtype == torch.float16:
|
53 |
+
assert args.mixed_precision in ["fp16", "no"], "DiT weights are in fp16, mixed precision must be fp16 or no"
|
54 |
+
elif self.dit_dtype == torch.bfloat16:
|
55 |
+
assert args.mixed_precision in ["bf16", "no"], "DiT weights are in bf16, mixed precision must be bf16 or no"
|
56 |
+
|
57 |
+
if args.fp8_scaled and self.dit_dtype.itemsize == 1:
|
58 |
+
raise ValueError(
|
59 |
+
"DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
|
60 |
+
)
|
61 |
+
|
62 |
+
# dit_dtype cannot be fp8, so we select the appropriate dtype
|
63 |
+
if self.dit_dtype.itemsize == 1:
|
64 |
+
self.dit_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
65 |
+
|
66 |
+
args.dit_dtype = model_utils.dtype_to_str(self.dit_dtype)
|
67 |
+
|
68 |
+
self.default_guidance_scale = 1.0 # not used
|
69 |
+
|
70 |
+
def process_sample_prompts(
|
71 |
+
self,
|
72 |
+
args: argparse.Namespace,
|
73 |
+
accelerator: Accelerator,
|
74 |
+
sample_prompts: str,
|
75 |
+
):
|
76 |
+
config = self.config
|
77 |
+
device = accelerator.device
|
78 |
+
t5_path, clip_path, fp8_t5 = args.t5, args.clip, args.fp8_t5
|
79 |
+
|
80 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
|
81 |
+
prompts = load_prompts(sample_prompts)
|
82 |
+
|
83 |
+
def encode_for_text_encoder(text_encoder):
|
84 |
+
sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
|
85 |
+
# with accelerator.autocast(), torch.no_grad(): # this causes NaN if dit_dtype is fp16
|
86 |
+
t5_dtype = config.t5_dtype
|
87 |
+
with torch.amp.autocast(device_type=device.type, dtype=t5_dtype), torch.no_grad():
|
88 |
+
for prompt_dict in prompts:
|
89 |
+
if "negative_prompt" not in prompt_dict:
|
90 |
+
prompt_dict["negative_prompt"] = self.config["sample_neg_prompt"]
|
91 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
|
92 |
+
if p is None:
|
93 |
+
continue
|
94 |
+
if p not in sample_prompts_te_outputs:
|
95 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
96 |
+
|
97 |
+
prompt_outputs = text_encoder([p], device)
|
98 |
+
sample_prompts_te_outputs[p] = prompt_outputs
|
99 |
+
|
100 |
+
return sample_prompts_te_outputs
|
101 |
+
|
102 |
+
# Load Text Encoder 1 and encode
|
103 |
+
logger.info(f"loading T5: {t5_path}")
|
104 |
+
t5 = T5EncoderModel(text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=t5_path, fp8=fp8_t5)
|
105 |
+
|
106 |
+
logger.info("encoding with Text Encoder 1")
|
107 |
+
te_outputs_1 = encode_for_text_encoder(t5)
|
108 |
+
del t5
|
109 |
+
|
110 |
+
# load CLIP and encode image (for I2V training)
|
111 |
+
# Note: VAE encoding is done in do_inference() for I2V training, because we have VAE in the pipeline. Control video is also done in do_inference()
|
112 |
+
sample_prompts_image_embs = {}
|
113 |
+
for prompt_dict in prompts:
|
114 |
+
if prompt_dict.get("image_path", None) is not None and self.i2v_training:
|
115 |
+
sample_prompts_image_embs[prompt_dict["image_path"]] = None # this will be replaced with CLIP context
|
116 |
+
|
117 |
+
if len(sample_prompts_image_embs) > 0:
|
118 |
+
logger.info(f"loading CLIP: {clip_path}")
|
119 |
+
assert clip_path is not None, "CLIP path is required for I2V training / I2V学習にはCLIPのパスが必要です"
|
120 |
+
clip = CLIPModel(dtype=config.clip_dtype, device=device, weight_path=clip_path)
|
121 |
+
clip.model.to(device)
|
122 |
+
|
123 |
+
logger.info(f"Encoding image to CLIP context")
|
124 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
125 |
+
for image_path in sample_prompts_image_embs:
|
126 |
+
logger.info(f"Encoding image: {image_path}")
|
127 |
+
img = Image.open(image_path).convert("RGB")
|
128 |
+
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) # -1 to 1
|
129 |
+
clip_context = clip.visual([img[:, None, :, :]])
|
130 |
+
sample_prompts_image_embs[image_path] = clip_context
|
131 |
+
|
132 |
+
del clip
|
133 |
+
clean_memory_on_device(device)
|
134 |
+
|
135 |
+
# prepare sample parameters
|
136 |
+
sample_parameters = []
|
137 |
+
for prompt_dict in prompts:
|
138 |
+
prompt_dict_copy = prompt_dict.copy()
|
139 |
+
|
140 |
+
p = prompt_dict.get("prompt", "")
|
141 |
+
prompt_dict_copy["t5_embeds"] = te_outputs_1[p][0]
|
142 |
+
|
143 |
+
p = prompt_dict.get("negative_prompt", None)
|
144 |
+
if p is not None:
|
145 |
+
prompt_dict_copy["negative_t5_embeds"] = te_outputs_1[p][0]
|
146 |
+
|
147 |
+
p = prompt_dict.get("image_path", None)
|
148 |
+
if p is not None and self.i2v_training:
|
149 |
+
prompt_dict_copy["clip_embeds"] = sample_prompts_image_embs[p]
|
150 |
+
|
151 |
+
sample_parameters.append(prompt_dict_copy)
|
152 |
+
|
153 |
+
clean_memory_on_device(accelerator.device)
|
154 |
+
|
155 |
+
return sample_parameters
|
156 |
+
|
157 |
+
def do_inference(
|
158 |
+
self,
|
159 |
+
accelerator,
|
160 |
+
args,
|
161 |
+
sample_parameter,
|
162 |
+
vae,
|
163 |
+
dit_dtype,
|
164 |
+
transformer,
|
165 |
+
discrete_flow_shift,
|
166 |
+
sample_steps,
|
167 |
+
width,
|
168 |
+
height,
|
169 |
+
frame_count,
|
170 |
+
generator,
|
171 |
+
do_classifier_free_guidance,
|
172 |
+
guidance_scale,
|
173 |
+
cfg_scale,
|
174 |
+
image_path=None,
|
175 |
+
control_video_path=None,
|
176 |
+
):
|
177 |
+
"""architecture dependent inference"""
|
178 |
+
model: WanModel = transformer
|
179 |
+
device = accelerator.device
|
180 |
+
if cfg_scale is None:
|
181 |
+
cfg_scale = 5.0
|
182 |
+
do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
|
183 |
+
|
184 |
+
# Calculate latent video length based on VAE version
|
185 |
+
latent_video_length = (frame_count - 1) // self.config["vae_stride"][0] + 1
|
186 |
+
|
187 |
+
# Get embeddings
|
188 |
+
context = sample_parameter["t5_embeds"].to(device=device)
|
189 |
+
if do_classifier_free_guidance:
|
190 |
+
context_null = sample_parameter["negative_t5_embeds"].to(device=device)
|
191 |
+
else:
|
192 |
+
context_null = None
|
193 |
+
|
194 |
+
num_channels_latents = 16 # model.in_dim
|
195 |
+
vae_scale_factor = self.config["vae_stride"][1]
|
196 |
+
|
197 |
+
# Initialize latents
|
198 |
+
lat_h = height // vae_scale_factor
|
199 |
+
lat_w = width // vae_scale_factor
|
200 |
+
shape_or_frame = (1, num_channels_latents, 1, lat_h, lat_w)
|
201 |
+
latents = []
|
202 |
+
for _ in range(latent_video_length):
|
203 |
+
latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=torch.float32))
|
204 |
+
latents = torch.cat(latents, dim=2)
|
205 |
+
|
206 |
+
image_latents = None
|
207 |
+
if self.i2v_training or self.control_training:
|
208 |
+
# Move VAE to the appropriate device for sampling: consider to cache image latents in CPU in advance
|
209 |
+
vae.to(device)
|
210 |
+
vae.eval()
|
211 |
+
|
212 |
+
if self.i2v_training:
|
213 |
+
image = Image.open(image_path)
|
214 |
+
image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
|
215 |
+
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).float() # C, 1, H, W
|
216 |
+
image = image / 127.5 - 1 # -1 to 1
|
217 |
+
|
218 |
+
# Create mask for the required number of frames
|
219 |
+
msk = torch.ones(1, frame_count, lat_h, lat_w, device=device)
|
220 |
+
msk[:, 1:] = 0
|
221 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
222 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
223 |
+
msk = msk.transpose(1, 2) # B, C, T, H, W
|
224 |
+
|
225 |
+
with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
|
226 |
+
# Zero padding for the required number of frames only
|
227 |
+
padding_frames = frame_count - 1 # The first frame is the input image
|
228 |
+
image = torch.concat([image, torch.zeros(3, padding_frames, height, width)], dim=1).to(device=device)
|
229 |
+
y = vae.encode([image])[0]
|
230 |
+
|
231 |
+
y = y[:, :latent_video_length] # may be not needed
|
232 |
+
y = y.unsqueeze(0) # add batch dim
|
233 |
+
image_latents = torch.concat([msk, y], dim=1)
|
234 |
+
|
235 |
+
if self.control_training:
|
236 |
+
# Control video
|
237 |
+
video = load_video(control_video_path, 0, frame_count, bucket_reso=(width, height)) # list of frames
|
238 |
+
video = np.stack(video, axis=0) # F, H, W, C
|
239 |
+
video = torch.from_numpy(video).permute(3, 0, 1, 2).float() # C, F, H, W
|
240 |
+
video = video / 127.5 - 1 # -1 to 1
|
241 |
+
video = video.to(device=device)
|
242 |
+
|
243 |
+
with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
|
244 |
+
control_latents = vae.encode([video])[0]
|
245 |
+
control_latents = control_latents[:, :latent_video_length]
|
246 |
+
control_latents = control_latents.unsqueeze(0) # add batch dim
|
247 |
+
|
248 |
+
# We supports Wan2.1-Fun-Control only
|
249 |
+
if image_latents is not None:
|
250 |
+
image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
|
251 |
+
image_latents[:, :, 1:] = 0 # remove except the first frame
|
252 |
+
else:
|
253 |
+
image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
|
254 |
+
|
255 |
+
image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
|
256 |
+
|
257 |
+
vae.to("cpu")
|
258 |
+
clean_memory_on_device(device)
|
259 |
+
|
260 |
+
# use the default value for num_train_timesteps (1000)
|
261 |
+
scheduler = FlowUniPCMultistepScheduler(shift=1, use_dynamic_shifting=False)
|
262 |
+
scheduler.set_timesteps(sample_steps, device=device, shift=discrete_flow_shift)
|
263 |
+
timesteps = scheduler.timesteps
|
264 |
+
|
265 |
+
# Generate noise for the required number of frames only
|
266 |
+
noise = torch.randn(16, latent_video_length, lat_h, lat_w, dtype=torch.float32, generator=generator, device=device).to(
|
267 |
+
"cpu"
|
268 |
+
)
|
269 |
+
|
270 |
+
# prepare the model input
|
271 |
+
max_seq_len = latent_video_length * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
|
272 |
+
arg_c = {"context": [context], "seq_len": max_seq_len}
|
273 |
+
arg_null = {"context": [context_null], "seq_len": max_seq_len}
|
274 |
+
|
275 |
+
if self.i2v_training:
|
276 |
+
arg_c["clip_fea"] = sample_parameter["clip_embeds"].to(device=device, dtype=dit_dtype)
|
277 |
+
arg_null["clip_fea"] = arg_c["clip_fea"]
|
278 |
+
if self.i2v_training or self.control_training:
|
279 |
+
arg_c["y"] = image_latents
|
280 |
+
arg_null["y"] = image_latents
|
281 |
+
|
282 |
+
# Wrap the inner loop with tqdm to track progress over timesteps
|
283 |
+
prompt_idx = sample_parameter.get("enum", 0)
|
284 |
+
latent = noise
|
285 |
+
with torch.no_grad():
|
286 |
+
for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
|
287 |
+
latent_model_input = [latent.to(device=device)]
|
288 |
+
timestep = t.unsqueeze(0)
|
289 |
+
|
290 |
+
with accelerator.autocast():
|
291 |
+
noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
|
292 |
+
if do_classifier_free_guidance:
|
293 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
|
294 |
+
else:
|
295 |
+
noise_pred_uncond = None
|
296 |
+
|
297 |
+
if do_classifier_free_guidance:
|
298 |
+
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
|
299 |
+
else:
|
300 |
+
noise_pred = noise_pred_cond
|
301 |
+
|
302 |
+
temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator)[0]
|
303 |
+
latent = temp_x0.squeeze(0)
|
304 |
+
|
305 |
+
# Move VAE to the appropriate device for sampling
|
306 |
+
vae.to(device)
|
307 |
+
vae.eval()
|
308 |
+
|
309 |
+
# Decode latents to video
|
310 |
+
logger.info(f"Decoding video from latents: {latent.shape}")
|
311 |
+
latent = latent.unsqueeze(0) # add batch dim
|
312 |
+
latent = latent.to(device=device)
|
313 |
+
|
314 |
+
with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
|
315 |
+
video = vae.decode(latent)[0] # vae returns list
|
316 |
+
video = video.unsqueeze(0) # add batch dim
|
317 |
+
del latent
|
318 |
+
|
319 |
+
logger.info(f"Decoding complete")
|
320 |
+
video = video.to(torch.float32).cpu()
|
321 |
+
video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
|
322 |
+
|
323 |
+
vae.to("cpu")
|
324 |
+
clean_memory_on_device(device)
|
325 |
+
|
326 |
+
return video
|
327 |
+
|
328 |
+
def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
|
329 |
+
vae_path = args.vae
|
330 |
+
|
331 |
+
logger.info(f"Loading VAE model from {vae_path}")
|
332 |
+
cache_device = torch.device("cpu") if args.vae_cache_cpu else None
|
333 |
+
vae = WanVAE(vae_path=vae_path, device="cpu", dtype=vae_dtype, cache_device=cache_device)
|
334 |
+
return vae
|
335 |
+
|
336 |
+
def load_transformer(
|
337 |
+
self,
|
338 |
+
accelerator: Accelerator,
|
339 |
+
args: argparse.Namespace,
|
340 |
+
dit_path: str,
|
341 |
+
attn_mode: str,
|
342 |
+
split_attn: bool,
|
343 |
+
loading_device: str,
|
344 |
+
dit_weight_dtype: Optional[torch.dtype],
|
345 |
+
):
|
346 |
+
model = load_wan_model(
|
347 |
+
self.config, accelerator.device, dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.fp8_scaled
|
348 |
+
)
|
349 |
+
return model
|
350 |
+
|
351 |
+
def scale_shift_latents(self, latents):
|
352 |
+
return latents
|
353 |
+
|
354 |
+
def call_dit(
|
355 |
+
self,
|
356 |
+
args: argparse.Namespace,
|
357 |
+
accelerator: Accelerator,
|
358 |
+
transformer,
|
359 |
+
latents: torch.Tensor,
|
360 |
+
batch: dict[str, torch.Tensor],
|
361 |
+
noise: torch.Tensor,
|
362 |
+
noisy_model_input: torch.Tensor,
|
363 |
+
timesteps: torch.Tensor,
|
364 |
+
network_dtype: torch.dtype,
|
365 |
+
):
|
366 |
+
model: WanModel = transformer
|
367 |
+
|
368 |
+
# I2V training and Control training
|
369 |
+
image_latents = None
|
370 |
+
clip_fea = None
|
371 |
+
if self.i2v_training:
|
372 |
+
image_latents = batch["latents_image"]
|
373 |
+
image_latents = image_latents.to(device=accelerator.device, dtype=network_dtype)
|
374 |
+
clip_fea = batch["clip"]
|
375 |
+
clip_fea = clip_fea.to(device=accelerator.device, dtype=network_dtype)
|
376 |
+
if self.control_training:
|
377 |
+
control_latents = batch["latents_control"]
|
378 |
+
control_latents = control_latents.to(device=accelerator.device, dtype=network_dtype)
|
379 |
+
if image_latents is not None:
|
380 |
+
image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
|
381 |
+
image_latents[:, :, 1:] = 0 # remove except the first frame
|
382 |
+
else:
|
383 |
+
image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
|
384 |
+
image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
|
385 |
+
control_latents = None
|
386 |
+
|
387 |
+
context = [t.to(device=accelerator.device, dtype=network_dtype) for t in batch["t5"]]
|
388 |
+
|
389 |
+
# ensure the hidden state will require grad
|
390 |
+
if args.gradient_checkpointing:
|
391 |
+
noisy_model_input.requires_grad_(True)
|
392 |
+
for t in context:
|
393 |
+
t.requires_grad_(True)
|
394 |
+
if image_latents is not None:
|
395 |
+
image_latents.requires_grad_(True)
|
396 |
+
if clip_fea is not None:
|
397 |
+
clip_fea.requires_grad_(True)
|
398 |
+
|
399 |
+
# call DiT
|
400 |
+
lat_f, lat_h, lat_w = latents.shape[2:5]
|
401 |
+
seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[0] * self.config.patch_size[1] * self.config.patch_size[2])
|
402 |
+
latents = latents.to(device=accelerator.device, dtype=network_dtype)
|
403 |
+
noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
|
404 |
+
with accelerator.autocast():
|
405 |
+
model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
|
406 |
+
model_pred = torch.stack(model_pred, dim=0) # list to tensor
|
407 |
+
|
408 |
+
# flow matching loss
|
409 |
+
target = noise - latents
|
410 |
+
|
411 |
+
return model_pred, target
|
412 |
+
|
413 |
+
# endregion model specific
|
414 |
+
|
415 |
+
|
416 |
+
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
417 |
+
"""Wan2.1 specific parser setup"""
|
418 |
+
parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
|
419 |
+
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
420 |
+
parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
|
421 |
+
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
|
422 |
+
parser.add_argument(
|
423 |
+
"--clip",
|
424 |
+
type=str,
|
425 |
+
default=None,
|
426 |
+
help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
|
427 |
+
)
|
428 |
+
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
|
429 |
+
return parser
|
430 |
+
|
431 |
+
|
432 |
+
if __name__ == "__main__":
|
433 |
+
parser = setup_parser_common()
|
434 |
+
parser = wan_setup_parser(parser)
|
435 |
+
|
436 |
+
args = parser.parse_args()
|
437 |
+
args = read_config_from_file(args, parser)
|
438 |
+
|
439 |
+
args.dit_dtype = None # automatically detected
|
440 |
+
if args.vae_dtype is None:
|
441 |
+
args.vae_dtype = "bfloat16" # make bfloat16 as default for VAE
|
442 |
+
|
443 |
+
trainer = WanNetworkTrainer()
|
444 |
+
trainer.train(args)
|