from concurrent.futures import ThreadPoolExecutor
import glob
import json
import logging
import math
import os
import random
import time
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from safetensors.torch import save_file, load_file
from safetensors import safe_open
from PIL import Image
import cv2
import av

from utils import safetensors_utils
from utils.model_utils import dtype_to_str

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass

# JPEG-XL on Linux
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

# JPEG-XL on Windows
try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

VIDEO_EXTENSIONS = [
    ".mp4",
    ".webm",
    ".avi",
    ".mkv",
    ".mov",
    ".flv",
    ".wmv",
    ".m4v",
    ".mpg",
    ".mpeg",
    ".MP4",
    ".WEBM",
    ".AVI",
    ".MKV",
    ".MOV",
    ".FLV",
    ".WMV",
    ".M4V",
    ".MPG",
    ".MPEG",
]  # some of them are not tested

ARCHITECTURE_HUNYUAN_VIDEO = "hv"

def glob_images(directory, base="*"):
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    img_paths = list(set(img_paths))  # remove duplicates
    img_paths.sort()
    return img_paths


def glob_videos(directory, base="*"):
    video_paths = []
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    video_paths = list(set(video_paths))  # remove duplicates
    video_paths.sort()
    return video_paths


def divisible_by(num: int, divisor: int) -> int:
    return num - num % divisor
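

# Illustrative sketch (not used by the pipeline): divisible_by rounds down to
# the nearest multiple of the divisor, which is how bucket sides are snapped to
# the 16-px resolution steps used below.
def _example_divisible_by():
    assert divisible_by(544, 16) == 544  # already a multiple, unchanged
    assert divisible_by(550, 16) == 544  # rounded down to a multiple of 16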


def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        image_height, image_width = image.shape[:2]

    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image

    bucket_width, bucket_height = bucket_reso

    if bucket_width == image_width or bucket_height == image_height:
        image = np.array(image) if is_pil_image else image
    else:
        # resize the image to the bucket resolution to match the short side
        scale_width = bucket_width / image_width
        scale_height = bucket_height / image_height
        scale = max(scale_width, scale_height)
        image_width = int(image_width * scale + 0.5)
        image_height = int(image_height * scale + 0.5)

        if scale > 1:
            image = Image.fromarray(image) if not is_pil_image else image
            image = image.resize((image_width, image_height), Image.LANCZOS)
            image = np.array(image)
        else:
            image = np.array(image) if is_pil_image else image
            image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)

    # crop the image to the bucket resolution
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
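

# Illustrative sketch (not used by the pipeline): scale a dummy 1280x720 frame
# into a 960x544 bucket. The frame is resized to cover the bucket, then
# center-cropped, so the result always has exactly the bucket size.
def _example_resize_to_bucket():
    dummy = np.zeros((720, 1280, 3), dtype=np.uint8)  # H, W, C
    resized = resize_image_to_bucket(dummy, (960, 544))  # bucket is (W, H)
    assert resized.shape[:2] == (544, 960)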


class ItemInfo:
    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        self.item_key = item_key
        self.caption = caption
        self.original_size = original_size
        self.bucket_size = bucket_size
        self.frame_count = frame_count
        self.content = content
        self.latent_cache_path = latent_cache_path
        self.text_encoder_output_cache_path: Optional[str] = None

    def __str__(self) -> str:
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
        )


def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"

    # NaN check and show warning, replace NaN with 0
    if torch.isnan(latent).any():
        logger.warning(f"latent tensor has NaN: {item_info.item_key}, replace NaN with 0")
        latent[torch.isnan(latent)] = 0

    metadata = {
        "architecture": "hunyuan_video",
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.0",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    _, F, H, W = latent.shape  # (channel, frame, height, width)
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    latent_dir = os.path.dirname(item_info.latent_cache_path)
    os.makedirs(latent_dir, exist_ok=True)

    save_file(sd, item_info.latent_cache_path, metadata=metadata)
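

# Illustrative sketch (not used by the pipeline): round-trip a dummy latent
# through save_latent_cache and safetensors' load_file. The temporary path and
# the tensor shape are made up for the example.
def _example_latent_cache_roundtrip():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        item_info = ItemInfo("dummy.png", "a caption", (960, 544))
        item_info.latent_cache_path = os.path.join(tmp_dir, f"dummy_0960x0544_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")

        latent = torch.zeros(16, 1, 68, 120)  # (channel, frame, height, width)
        save_latent_cache(item_info, latent)

        sd = load_file(item_info.latent_cache_path)
        assert any(key.startswith("latents_") for key in sd.keys())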


def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    assert (
        embed.dim() == 1 or embed.dim() == 2
    ), f"embed should be 2D tensor (feature, hidden_size) or 1D tensor (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"

    # NaN check and show warning, replace NaN with 0
    if torch.isnan(embed).any():
        logger.warning(f"embed tensor has NaN: {item_info.item_key}, replace NaN with 0")
        embed[torch.isnan(embed)] = 0

    metadata = {
        "architecture": "hunyuan_video",
        "caption1": item_info.caption,
        "format_version": "1.0.0",
    }

    sd = {}
    if os.path.exists(item_info.text_encoder_output_cache_path):
        # load existing cache and update metadata
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                sd[key] = f.get_tensor(key)

        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
        # TODO verify format_version

        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)  # copy existing metadata
    else:
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    dtype_str = dtype_to_str(embed.dtype)
    text_encoder_type = "llm" if is_llm else "clipL"
    sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
    if mask is not None:
        sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)


class BucketSelector:
    RESOLUTION_STEPS_HUNYUAN = 16

    def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN

        if not enable_bucket:
            # only define one bucket
            self.bucket_resolutions = [resolution]
            self.no_upscale = False
        else:
            # prepare bucket resolutions
            self.no_upscale = no_upscale

            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, self.reso_steps)
            self.bucket_resolutions = []
            for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
                h = divisible_by(self.bucket_area // w, self.reso_steps)
                self.bucket_resolutions.append((w, h))
                self.bucket_resolutions.append((h, w))

            self.bucket_resolutions = list(set(self.bucket_resolutions))
            self.bucket_resolutions.sort()

        # calculate aspect ratios to find the nearest resolution
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        area = image_size[0] * image_size[1]
        if self.no_upscale and area <= self.bucket_area:
            w, h = image_size
            w = divisible_by(w, self.reso_steps)
            h = divisible_by(h, self.reso_steps)
            return w, h

        aspect_ratio = image_size[0] / image_size[1]
        ar_errors = self.aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        return self.bucket_resolutions[bucket_id]
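

# Illustrative sketch (not used by the pipeline): with a 960x544 target area
# and 16-px steps, a 1920x1080 image maps to the bucket whose aspect ratio is
# closest to 16:9. Bucket sides are multiples of 16 and the bucket area never
# exceeds the target area.
def _example_bucket_selector():
    selector = BucketSelector((960, 544))
    bucket_w, bucket_h = selector.get_bucket_resolution((1920, 1080))
    assert bucket_w % 16 == 0 and bucket_h % 16 == 0
    assert bucket_w * bucket_h <= 960 * 544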


def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
    bucket_reso: Optional[tuple[int, int]] = None,
) -> list[np.ndarray]:
    """
    bucket_reso: if given, resize the video to the bucket resolution, (width, height)
    """
    container = av.open(video_path)
    video = []
    for i, frame in enumerate(container.decode(video=0)):
        if start_frame is not None and i < start_frame:
            continue
        if end_frame is not None and i >= end_frame:
            break
        frame = frame.to_image()

        if bucket_selector is not None and bucket_reso is None:
            bucket_reso = bucket_selector.get_bucket_resolution(frame.size)

        if bucket_reso is not None:
            frame = resize_image_to_bucket(frame, bucket_reso)
        else:
            frame = np.array(frame)

        video.append(frame)
    container.close()
    return video
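

# Illustrative sketch (not used by the pipeline): decode the first 25 frames of
# a clip and resize them on the fly. "/path/to/video.mp4" is a placeholder.
def _example_load_video():
    frames = load_video("/path/to/video.mp4", start_frame=0, end_frame=25, bucket_selector=BucketSelector((960, 544)))
    # each element is an HxWx3 uint8 array with the same bucket resolution
    print(len(frames), frames[0].shape)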


class BucketBatchManager:
    def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        self.bucket_batch_indices = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        self.shuffle()

    def show_bucket_info(self):
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        for bucket in self.buckets.values():
            random.shuffle(bucket)
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        latents = []
        llm_embeds = []
        llm_masks = []
        clip_l_embeds = []
        for item_info in bucket[start:end]:
            sd = load_file(item_info.latent_cache_path)
            latent = None
            for key in sd.keys():
                if key.startswith("latents_"):
                    latent = sd[key]
                    break
            latents.append(latent)

            sd = load_file(item_info.text_encoder_output_cache_path)
            llm_embed = llm_mask = clip_l_embed = None
            for key in sd.keys():
                if key.startswith("llm_mask"):
                    llm_mask = sd[key]
                elif key.startswith("llm_"):
                    llm_embed = sd[key]
                elif key.startswith("clipL_mask"):
                    pass
                elif key.startswith("clipL_"):
                    clip_l_embed = sd[key]
            llm_embeds.append(llm_embed)
            llm_masks.append(llm_mask)
            clip_l_embeds.append(clip_l_embed)

        latents = torch.stack(latents)
        llm_embeds = torch.stack(llm_embeds)
        llm_masks = torch.stack(llm_masks)
        clip_l_embeds = torch.stack(clip_l_embeds)

        return latents, llm_embeds, llm_masks, clip_l_embeds
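

# Note (illustrative): BucketBatchManager.__getitem__ returns one pre-collated
# batch (latents, llm_embeds, llm_masks, clip_l_embeds), so datasets built on
# top of it are typically wrapped in a DataLoader with batch_size=1 (see the
# sketch at the end of this file).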


class ContentDatasource:
    def __init__(self):
        self.caption_only = False

    def set_caption_only(self, caption_only: bool):
        self.caption_only = caption_only

    def is_indexable(self):
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError


class ImageDatasource(ContentDatasource):
    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError


class ImageDirectoryDatasource(ImageDatasource):
    def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob images
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        image_path = self.image_paths[idx]
        if not self.caption_extension:
            raise ValueError("caption_extension is required to load captions from files")
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.
        """
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher


class ImageJsonlDatasource(ImageDatasource):
    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
                    raise
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
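

# Expected JSONL schema for ImageJsonlDatasource (one JSON object per line);
# the path and caption here are made-up examples:
#   {"image_path": "/path/to/image_0001.jpg", "caption": "a photo of ..."}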


class VideoDatasource(ContentDatasource):
    def __init__(self):
        super().__init__()

        # None means all frames
        self.start_frame = None
        self.end_frame = None

        self.bucket_selector = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        # this method can resize the video if bucket_selector is given to reduce the memory usage

        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        video = load_video(video_path, start_frame, end_frame, bucket_selector)
        return video

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        self.bucket_selector = bucket_selector

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError


class VideoDirectoryDatasource(VideoDatasource):
    def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        video_path = self.video_paths[idx]
        if not self.caption_extension:
            raise ValueError("caption_extension is required to load captions from files")
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher


class VideoJsonlDatasource(VideoDatasource):
    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
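

# Expected JSONL schema for VideoJsonlDatasource (one JSON object per line);
# the path and caption here are made-up examples:
#   {"video_path": "/path/to/clip_0001.mp4", "caption": "a clip of ..."}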


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        num_repeats: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.num_repeats = num_repeats
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.seed = None
        self.current_epoch = 0

        if not self.enable_bucket:
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "num_repeats": self.num_repeats,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_all_latent_cache_files(self):
        return glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

    def get_all_text_encoder_output_cache_files(self):
        return glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"))

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        """
        Returns the cache path for the latent tensor.

        item_info: ItemInfo object

        Returns:
            str: cache path

        cache_path is based on the item_key and the resolution.
        """
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required"
        return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        if self.current_epoch != epoch:  # shuffle buckets when epoch is incremented
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
                # self.current_epoch seems to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    item_key, caption = future.result()
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)
                    futures.remove(future)

        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()


class ImageDataset(BaseDataset):
    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(ImageDataset, self).__init__(
            resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        return metadata

    def get_total_image_count(self):
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        batches: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        futures = []

        # aggregate futures and sort by bucket resolution
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_size, item_key, image, caption = future.result()
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        # submit batch if some bucket has enough items
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch

            return None, None

        for fetch_op in self.datasource:

            # fetch and resize image in a separate thread
            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, np.ndarray, str]:
                image_key, image, caption = op()
                image: Image.Image
                image_size = image.size

                bucket_reso = bucket_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)  # returns np.ndarray

                return image_size, image_key, image, caption

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
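

# Illustrative sketch (not used by the pipeline): build an image dataset from a
# directory of images with .txt captions, assuming latents and text-encoder
# outputs were already cached there. The directory path is a placeholder.
def _example_image_dataset():
    dataset = ImageDataset(
        resolution=(960, 544),
        caption_extension=".txt",
        batch_size=4,
        num_repeats=1,
        enable_bucket=True,
        bucket_no_upscale=False,
        image_directory="/path/to/images",  # placeholder
    )
    dataset.set_seed(42)
    dataset.prepare_for_training()  # scans *_hv.safetensors caches and builds buckets
    latents, llm_embeds, llm_masks, clip_l_embeds = dataset[0]  # one pre-collated batch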


class VideoDataset(BaseDataset):
    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(VideoDataset, self).__init__(
            resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.target_frames = target_frames
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            raise ValueError("video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            # head extraction. we can limit the number of frames to be extracted
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        if self.cache_directory is None:
            self.cache_directory = self.video_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution)
        self.datasource.set_bucket_selector(bucket_selector)

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: (width, height, frame_count), value: [ItemInfo]
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_frame_size, video_key, video, caption = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # already resized

                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # split by target_frames
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # sliding window
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # select N frames uniformly
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket_reso with frame_count

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch

            return None, None

        for operator in self.datasource:

            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
                video_key, video, caption = op()
                video: list[np.ndarray]
                frame_size = (video[0].shape[1], video[0].shape[0])

                # resize if necessary
                bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
                video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

                return frame_size, video_key, video, caption

            future = executor.submit(fetch_and_resize, operator)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}  # (width, height, frame_count) -> [ItemInfo]
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            frame_pos, frame_count = tokens[-3].split("-")
            frame_pos, frame_count = int(frame_pos), int(frame_count)

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]


class DatasetGroup(torch.utils.data.ConcatDataset):
    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
        self.num_train_items = 0
        for dataset in self.datasets:
            self.num_train_items += dataset.num_train_items

    def set_current_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def set_current_step(self, step):
        for dataset in self.datasets:
            dataset.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for dataset in self.datasets:
            dataset.set_max_train_steps(max_train_steps)
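

# Illustrative end-to-end sketch (assumptions: caches already exist on disk and
# the placeholder path is replaced with a real one). Since each dataset item is
# already a full pre-collated batch, the DataLoader uses batch_size=1.
if __name__ == "__main__":
    image_dataset = ImageDataset(
        resolution=(960, 544),
        caption_extension=".txt",
        batch_size=4,
        num_repeats=1,
        enable_bucket=True,
        bucket_no_upscale=False,
        image_directory="/path/to/images",  # placeholder
    )
    image_dataset.set_seed(42)
    image_dataset.prepare_for_training()

    group = DatasetGroup([image_dataset])
    group.set_current_epoch(1)

    loader = torch.utils.data.DataLoader(group, batch_size=1, shuffle=False, num_workers=0)
    for latents, llm_embeds, llm_masks, clip_l_embeds in loader:
        logger.info(f"latents: {latents.shape}")  # note the extra leading dim added by the DataLoader
        break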