import ast
import asyncio
from datetime import datetime, timedelta
import gc
import importlib
import argparse
import math
import os
import pathlib
import re
import sys
import random
import time
import json
from multiprocessing import Value
from typing import Any, Dict, List, Optional

import accelerate
import numpy as np
from packaging.version import Version
import huggingface_hub
import toml
import torch
from tqdm import tqdm
from accelerate.utils import set_seed
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from safetensors.torch import load_file, save_file
import transformers
from diffusers.optimization import (
    SchedulerType as DiffusersSchedulerType,
    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
)
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

from dataset import config_utils
from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
import hunyuan_model.text_encoder as text_encoder_module
from hunyuan_model.vae import load_vae
import hunyuan_model.vae as vae_module
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
import networks.lora as lora_module
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer

import logging

from utils import huggingface_utils, model_utils, train_utils, sai_model_spec

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"


# TODO move some of these functions to a separate file to share them with other scripts
def clean_memory_on_device(device: torch.device):
    r"""
    Clean memory on the specified device; called from training scripts.
    """
    gc.collect()

    # device may be "cuda" or "cuda:0", so check device.type
    if device.type == "cuda":
        torch.cuda.empty_cache()
    if device.type == "xpu":
        torch.xpu.empty_cache()
    if device.type == "mps":
        torch.mps.empty_cache()


# for collate_fn: epoch and step are multiprocessing.Value
class collator_class:
    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch
        self.current_step = step
        self.dataset = dataset  # only used in the main process; worker processes get the dataset from worker_info

    def __call__(self, examples):
        worker_info = torch.utils.data.get_worker_info()
        # worker_info is None in the main process
        if worker_info is not None:
            dataset = worker_info.dataset
        else:
            dataset = self.dataset

        # set epoch and step
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]
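
# Illustrative wiring (mirrors train() below): the shared Values let the training loop
# push the current epoch/step into DataLoader worker processes.
#   current_epoch, current_step = Value("i", 0), Value("i", 0)
#   collator = collator_class(current_epoch, current_step, dataset_or_none)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collator, num_workers=n)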


def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
    """
    DeepSpeed is not supported in this script currently.
    """
    if args.logging_dir is None:
        logging_dir = None
    else:
        log_prefix = "" if args.log_prefix is None else args.log_prefix
        logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

    if args.log_with is None:
        if logging_dir is not None:
            log_with = "tensorboard"
        else:
            log_with = None
    else:
        log_with = args.log_with
        if log_with in ["tensorboard", "all"]:
            if logging_dir is None:
                raise ValueError(
                    "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
                )
        if log_with in ["wandb", "all"]:
            try:
                import wandb
            except ImportError:
                raise ImportError("No wandb / wandb がインストールされていないようです")
            if logging_dir is not None:
                os.makedirs(logging_dir, exist_ok=True)
                os.environ["WANDB_DIR"] = logging_dir
            if args.wandb_api_key is not None:
                wandb.login(key=args.wandb_api_key)

    kwargs_handlers = [
        (
            InitProcessGroupKwargs(
                backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
                init_method=(
                    "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
                ),
                timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
            )
            if torch.cuda.device_count() > 1
            else None
        ),
        (
            DistributedDataParallelKwargs(
                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
            )
            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
            else None
        ),
    ]
    kwargs_handlers = [i for i in kwargs_handlers if i is not None]

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=log_with,
        project_dir=logging_dir,
        kwargs_handlers=kwargs_handlers,
    )
    print("accelerator device:", accelerator.device)
    return accelerator


def line_to_prompt_dict(line: str) -> dict:
    # subset of gen_img_diffusers
    prompt_args = line.split(" --")
    prompt_dict = {}
    prompt_dict["prompt"] = prompt_args[0]

    for parg in prompt_args:
        try:
            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
            if m:
                prompt_dict["width"] = int(m.group(1))
                continue

            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
            if m:
                prompt_dict["height"] = int(m.group(1))
                continue

            m = re.match(r"f (\d+)", parg, re.IGNORECASE)
            if m:
                prompt_dict["frame_count"] = int(m.group(1))
                continue

            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
            if m:
                prompt_dict["seed"] = int(m.group(1))
                continue

            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
            if m:  # steps
                prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
                continue

            # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
            # if m:  # scale
            #     prompt_dict["scale"] = float(m.group(1))
            #     continue
            # m = re.match(r"n (.+)", parg, re.IGNORECASE)
            # if m:  # negative prompt
            #     prompt_dict["negative_prompt"] = m.group(1)
            #     continue

        except ValueError as ex:
            logger.error(f"Exception in parsing / 解析エラー: {parg}")
            logger.error(ex)

    return prompt_dict
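
# Illustrative prompt line (hypothetical values) and the dict produced by the parser above:
#   "a cat walking --w 512 --h 320 --f 25 --d 42 --s 20"
#   -> {"prompt": "a cat walking", "width": 512, "height": 320, "frame_count": 25, "seed": 42, "sample_steps": 20}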


def load_prompts(prompt_file: str) -> list[Dict]:
    # read prompts
    if prompt_file.endswith(".txt"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
    elif prompt_file.endswith(".toml"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            data = toml.load(f)
        prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
    elif prompt_file.endswith(".json"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            prompts = json.load(f)
    else:
        raise ValueError(f"unsupported prompt file extension / 対応していないプロンプトファイルの拡張子です: {prompt_file}")

    # preprocess prompts
    for i in range(len(prompts)):
        prompt_dict = prompts[i]
        if isinstance(prompt_dict, str):
            prompt_dict = line_to_prompt_dict(prompt_dict)
            prompts[i] = prompt_dict
        assert isinstance(prompt_dict, dict)

        # Add an enumerator to the dict based on prompt position; it is used later to name image files.
        # Also clean up extra data in the original prompt dict.
        prompt_dict["enum"] = i
        prompt_dict.pop("subset", None)

    return prompts
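
# A minimal .toml prompt file matching the parsing above (illustrative only): top-level
# [prompt] keys are shared defaults; each [[prompt.subset]] entry becomes one prompt dict.
#   [prompt]
#   sample_steps = 20
#
#   [[prompt.subset]]
#   prompt = "a cat walking"
#
#   [[prompt.subset]]
#   prompt = "a dog running"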


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u
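
# The sampled density u in [0, 1] is mapped to discrete timestep indices downstream
# (see get_noisy_model_input_and_timesteps below): indices = (u * (t_max - t_min) + t_min).long()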


def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)

    # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
    if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
        # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
        # round to the nearest timestep
        logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
        step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
    else:
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
        if weighting_scheme == "sigma_sqrt":
            weighting = (sigmas**-2.0).float()
        else:
            bot = 1 - 2 * sigmas + 2 * sigmas**2
            weighting = 2 / (math.pi * bot)
    else:
        weighting = None  # torch.ones_like(sigmas)
    return weighting
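
# In formulas: "sigma_sqrt" weights the per-element loss by sigma^-2, while "cosmap" uses
# w(sigma) = 2 / (pi * (1 - 2*sigma + 2*sigma^2)); both are evaluated at each batch timestep.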


class FineTuningTrainer:
    def __init__(self):
        pass

    def process_sample_prompts(
        self,
        args: argparse.Namespace,
        accelerator: Accelerator,
        sample_prompts: str,
        text_encoder1: str,
        text_encoder2: str,
        fp8_llm: bool,
    ):
        logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
        prompts = load_prompts(sample_prompts)

        def encode_for_text_encoder(text_encoder, is_llm=True):
            sample_prompts_te_outputs = {}  # (prompt) -> (embeds, mask)
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for p in [prompt_dict.get("prompt", "")]:
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"cache Text Encoder outputs for prompt: {p}")

                            data_type = "video"
                            text_inputs = text_encoder.text2tokens(p, data_type=data_type)

                            prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
                            sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)

            return sample_prompts_te_outputs

        # Load Text Encoder 1 and encode
        text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
        logger.info(f"loading text encoder 1: {text_encoder1}")
        text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)

        logger.info("encoding with Text Encoder 1")
        te_outputs_1 = encode_for_text_encoder(text_encoder_1)
        del text_encoder_1

        # Load Text Encoder 2 and encode
        logger.info(f"loading text encoder 2: {text_encoder2}")
        text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)

        logger.info("encoding with Text Encoder 2")
        te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
        del text_encoder_2

        # prepare sample parameters
        sample_parameters = []
        for prompt_dict in prompts:
            prompt_dict_copy = prompt_dict.copy()
            p = prompt_dict.get("prompt", "")
            prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
            prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
            prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
            prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
            sample_parameters.append(prompt_dict_copy)

        clean_memory_on_device(accelerator.device)

        return sample_parameters

    def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer, Any, Any]:
        # adamw, adamw8bit, adafactor
        optimizer_type = args.optimizer_type.lower()

        # split optimizer_type and optimizer_args
        optimizer_kwargs = {}
        if args.optimizer_args is not None and len(args.optimizer_args) > 0:
            for arg in args.optimizer_args:
                key, value = arg.split("=")
                value = ast.literal_eval(value)
                optimizer_kwargs[key] = value
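        # e.g. --optimizer_args "weight_decay=0.01" "betas=(0.9, 0.99)" parses via
        # ast.literal_eval to {"weight_decay": 0.01, "betas": (0.9, 0.99)} (illustrative values)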

        lr = args.learning_rate
        optimizer = None
        optimizer_class = None

        if optimizer_type.endswith("8bit"):
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

            if optimizer_type == "adamw8bit":
                logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
                optimizer_class = bnb.optim.AdamW8bit
                optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

        elif optimizer_type == "adafactor":
            # Adafactor: check relative_step and warmup_init
            if "relative_step" not in optimizer_kwargs:
                optimizer_kwargs["relative_step"] = True  # default
            if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
                logger.info(
                    "set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
                )
                optimizer_kwargs["relative_step"] = True
            logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")

            if optimizer_kwargs["relative_step"]:
                logger.info("relative_step is true / relative_stepがtrueです")
                if lr != 0.0:
                    logger.warning("learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
                args.learning_rate = None

                if args.lr_scheduler != "adafactor":
                    logger.info("use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
                args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, but it works

                lr = None
            else:
                if args.max_grad_norm != 0.0:
                    logger.warning(
                        "because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
                    )
                if args.lr_scheduler != "constant_with_warmup":
                    logger.warning("constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
                if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
                    logger.warning("clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")

            optimizer_class = transformers.optimization.Adafactor
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

        elif optimizer_type == "adamw":
            logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
            optimizer_class = torch.optim.AdamW
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

        if optimizer is None:
            # use any optimizer by specifying its full class path
            case_sensitive_optimizer_type = args.optimizer_type  # not lowered
            logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")

            if "." not in case_sensitive_optimizer_type:  # from torch.optim
                optimizer_module = torch.optim
            else:  # from another library
                values = case_sensitive_optimizer_type.split(".")
                optimizer_module = importlib.import_module(".".join(values[:-1]))
                case_sensitive_optimizer_type = values[-1]

            optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

        # for logging
        optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
        optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

        # get train and eval functions: schedule-free optimizers expose train()/eval() to toggle their state
        if hasattr(optimizer, "train") and callable(optimizer.train):
            train_fn = optimizer.train
            eval_fn = optimizer.eval
        else:
            train_fn = lambda: None
            eval_fn = lambda: None

        return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn

    def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
        return args.optimizer_type.lower().endswith("schedulefree")  # or args.optimizer_schedulefree_wrapper

    def get_dummy_scheduler(self, optimizer: torch.optim.Optimizer) -> Any:
        # dummy scheduler for schedule-free optimizers; supports only an empty step() and get_last_lr().
        # it is used for logging only.
        # it is not wrapped by the accelerator because it is not a subclass of torch.optim.lr_scheduler._LRScheduler.
        class DummyScheduler:
            def __init__(self, optimizer: torch.optim.Optimizer):
                self.optimizer = optimizer

            def step(self):
                pass

            def get_last_lr(self):
                return [group["lr"] for group in self.optimizer.param_groups]

        return DummyScheduler(optimizer)

    def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
        """
        Unified API to get any scheduler from its name.
        """
        # if schedulefree optimizer, return dummy scheduler
        if self.is_schedulefree_optimizer(optimizer, args):
            return self.get_dummy_scheduler(optimizer)

        name = args.lr_scheduler
        num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
        num_warmup_steps: Optional[int] = (
            int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
        )
        num_decay_steps: Optional[int] = (
            int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
        )
        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
        num_cycles = args.lr_scheduler_num_cycles
        power = args.lr_scheduler_power
        timescale = args.lr_scheduler_timescale
        min_lr_ratio = args.lr_scheduler_min_lr_ratio

        lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
        if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
            for arg in args.lr_scheduler_args:
                key, value = arg.split("=")
                value = ast.literal_eval(value)
                lr_scheduler_kwargs[key] = value

        def wrap_check_needless_num_warmup_steps(return_vals):
            if num_warmup_steps is not None and num_warmup_steps != 0:
                raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
            return return_vals

        # using any lr_scheduler from another library
        if args.lr_scheduler_type:
            lr_scheduler_type = args.lr_scheduler_type
            logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
            if "." not in lr_scheduler_type:  # default to torch.optim
                lr_scheduler_module = torch.optim.lr_scheduler
            else:
                values = lr_scheduler_type.split(".")
                lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
                lr_scheduler_type = values[-1]
            lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
            lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
            return lr_scheduler

        if name.startswith("adafactor"):
            assert (
                type(optimizer) == transformers.optimization.Adafactor
            ), "adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
            initial_lr = float(name.split(":")[1])
            # logger.info(f"adafactor scheduler init lr {initial_lr}")
            return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

        if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
            name = DiffusersSchedulerType(name)
            schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
            return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs

        name = SchedulerType(name)
        schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

        if name == SchedulerType.CONSTANT:
            return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

        # All other schedulers require `num_warmup_steps`
        if num_warmup_steps is None:
            raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

        if name == SchedulerType.CONSTANT_WITH_WARMUP:
            return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

        if name == SchedulerType.INVERSE_SQRT:
            return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)

        # All other schedulers require `num_training_steps`
        if num_training_steps is None:
            raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

        if name == SchedulerType.COSINE_WITH_RESTARTS:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles,
                **lr_scheduler_kwargs,
            )

        if name == SchedulerType.POLYNOMIAL:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                power=power,
                **lr_scheduler_kwargs,
            )

        if name == SchedulerType.COSINE_WITH_MIN_LR:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles / 2,
                min_lr_rate=min_lr_ratio,
                **lr_scheduler_kwargs,
            )

        # these schedulers do not require `num_decay_steps`
        if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                **lr_scheduler_kwargs,
            )

        # All other schedulers require `num_decay_steps`
        if num_decay_steps is None:
            raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
        if name == SchedulerType.WARMUP_STABLE_DECAY:
            return schedule_func(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_stable_steps=num_stable_steps,
                num_decay_steps=num_decay_steps,
                num_cycles=num_cycles / 2,
                min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
                **lr_scheduler_kwargs,
            )

        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_decay_steps=num_decay_steps,
            **lr_scheduler_kwargs,
        )

    def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
        if not args.resume:
            return False

        if not args.resume_from_huggingface:
            logger.info(f"resume training from local state: {args.resume}")
            accelerator.load_state(args.resume)
            return True

        logger.info(f"resume training from huggingface state: {args.resume}")
        repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
        path_in_repo = "/".join(args.resume.split("/")[2:])
        revision = None
        repo_type = None
        if ":" in path_in_repo:
            divided = path_in_repo.split(":")
            if len(divided) == 2:
                path_in_repo, revision = divided
                repo_type = "model"
            else:
                path_in_repo, revision, repo_type = divided
        logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")

        list_files = huggingface_utils.list_dir(
            repo_id=repo_id,
            subfolder=path_in_repo,
            revision=revision,
            token=args.huggingface_token,
            repo_type=repo_type,
        )

        async def download(filename) -> str:
            def task():
                return huggingface_hub.hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    revision=revision,
                    repo_type=repo_type,
                    token=args.huggingface_token,
                )

            return await asyncio.get_event_loop().run_in_executor(None, task)

        loop = asyncio.get_event_loop()
        results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
        if len(results) == 0:
            raise ValueError(
                "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
            )
        dirname = os.path.dirname(results[0])
        accelerator.load_state(dirname)
        return True

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
        pass

    def get_noisy_model_input_and_timesteps(
        self,
        args: argparse.Namespace,
        noise: torch.Tensor,
        latents: torch.Tensor,
        noise_scheduler: FlowMatchDiscreteScheduler,
        device: torch.device,
        dtype: torch.dtype,
    ):
        batch_size = noise.shape[0]

        if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
            if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
                # Simple random t-based noise sampling
                if args.timestep_sampling == "sigmoid":
                    t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
                else:
                    t = torch.rand((batch_size,), device=device)
            elif args.timestep_sampling == "shift":
                shift = args.discrete_flow_shift
                logits_norm = torch.randn(batch_size, device=device)
                logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
                t = logits_norm.sigmoid()
                t = (t * shift) / (1 + (shift - 1) * t)
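                # the shift maps the sigmoid-distributed t via t <- s*t / (1 + (s-1)*t);
                # for s > 1 this biases sampling toward t ~ 1 (noisier inputs)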

            t_min = args.min_timestep if args.min_timestep is not None else 0
            t_max = args.max_timestep if args.max_timestep is not None else 1000.0
            t_min /= 1000.0
            t_max /= 1000.0
            t = t * (t_max - t_min) + t_min  # scale to [t_min, t_max], default [0, 1]

            timesteps = t * 1000.0
            t = t.view(-1, 1, 1, 1, 1)
            noisy_model_input = (1 - t) * latents + t * noise
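            # forward flow-matching interpolation: x_t = (1 - t) * x0 + t * noise,
            # matching the training target (noise - latents) used in the loss below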
            timesteps += 1  # 1 to 1000
        else:
            # Sample a random timestep for each image
            # for weighting schemes where we sample timesteps non-uniformly
            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=batch_size,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )
            # indices = (u * noise_scheduler.config.num_train_timesteps).long()
            t_min = args.min_timestep if args.min_timestep is not None else 0
            t_max = args.max_timestep if args.max_timestep is not None else 1000
            indices = (u * (t_max - t_min) + t_min).long()

            timesteps = noise_scheduler.timesteps[indices].to(device=device)  # 1 to 1000

            # Add noise according to flow matching.
            sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
            noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

        return noisy_model_input, timesteps

    def train(self, args):
        if args.seed is None:
            args.seed = random.randint(0, 2**32)
        set_seed(args.seed)

        # Load dataset config
        blueprint_generator = BlueprintGenerator(ConfigSanitizer())
        logger.info(f"Load dataset config from {args.dataset_config}")
        user_config = config_utils.load_user_config(args.dataset_config)
        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)

        current_epoch = Value("i", 0)
        current_step = Value("i", 0)
        ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
        collator = collator_class(current_epoch, current_step, ds_for_collator)

        # prepare accelerator
        logger.info("preparing accelerator")
        accelerator = prepare_accelerator(args)
        is_main_process = accelerator.is_main_process

        # prepare dtype
        weight_dtype = torch.float32
        if args.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # HunyuanVideo specific
        vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)

        # get embeddings for sampling images
        sample_parameters = vae = None
        if args.sample_prompts:
            sample_parameters = self.process_sample_prompts(
                args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
            )

            # Load VAE model for sampling images: VAE is loaded to CPU to save GPU memory
            vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
            vae.requires_grad_(False)
            vae.eval()

            if args.vae_chunk_size is not None:
                vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
                logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
            if args.vae_spatial_tile_sample_min_size is not None:
                vae.enable_spatial_tiling(True)
                vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
                vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
            elif args.vae_tiling:
                vae.enable_spatial_tiling(True)

        # load DiT model
        blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
        loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device

        logger.info(f"Loading DiT model from {args.dit}")
        if args.sdpa:
            attn_mode = "torch"
        elif args.flash_attn:
            attn_mode = "flash"
        elif args.sage_attn:
            attn_mode = "sageattn"
        elif args.xformers:
            attn_mode = "xformers"
        else:
            raise ValueError(
                "either --sdpa, --flash_attn, --sage_attn or --xformers must be specified / --sdpa, --flash_attn, --sage_attn, --xformersのいずれかを指定してください"
            )
        transformer = load_transformer(args.dit, attn_mode, args.split_attn, loading_device, None)  # load as-is

        if blocks_to_swap > 0:
            logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
            transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
            transformer.move_to_device_except_swap_blocks(accelerator.device)
        if args.img_in_txt_in_offloading:
            logger.info("Enable offloading img_in and txt_in to CPU")
            transformer.enable_img_in_txt_in_offloading()

        if args.gradient_checkpointing:
            transformer.enable_gradient_checkpointing()

        # prepare optimizer, data loader etc.
        accelerator.print("prepare optimizer, data loader etc.")

        transformer.requires_grad_(False)
        if accelerator.is_main_process:
            accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
        for name, param in transformer.named_parameters():
            for trainable_module_name in args.trainable_modules:
                if trainable_module_name in name:
                    param.requires_grad = True
                    break

        total_params = list(transformer.parameters())
        trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
        logger.info(
            f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, "
            f"total parameters: {sum(p.numel() for p in total_params) / 1e6} M"
        )

        optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
            args, trainable_params
        )

        # prepare dataloader
        # num workers for data loader: if 0, persistent_workers is not available
        n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset_group,
            batch_size=1,
            shuffle=True,
            collate_fn=collator,
            num_workers=n_workers,
            persistent_workers=args.persistent_data_loader_workers,
        )
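        # note: batch_size=1 here because each dataset item is already a full batch
        # (the collator above returns examples[0]); per-dataset batch sizes are reported below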

        # calculate max_train_steps
        if args.max_train_epochs is not None:
            args.max_train_steps = args.max_train_epochs * math.ceil(
                len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
            )
            accelerator.print(
                f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
            )

        # send max_train_steps to train_dataset_group
        train_dataset_group.set_max_train_steps(args.max_train_steps)

        # prepare lr_scheduler
        lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)

        # prepare training model. accelerator does some magic here
        # experimental feature: train the model with gradients in fp16/bf16
        dit_dtype = torch.float32
        if args.full_fp16:
            assert (
                args.mixed_precision == "fp16"
            ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
            accelerator.print("enable full fp16 training.")
            dit_weight_dtype = torch.float16
        elif args.full_bf16:
            assert (
                args.mixed_precision == "bf16"
            ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
            accelerator.print("enable full bf16 training.")
            dit_weight_dtype = torch.bfloat16
        else:
            dit_weight_dtype = torch.float32

        # TODO add fused optimizer and stochastic rounding

        # cast model to dit_weight_dtype
        # if dit_dtype != dit_weight_dtype:
        logger.info(f"casting model to {dit_weight_dtype}")
        transformer.to(dit_weight_dtype)

        if blocks_to_swap > 0:
            transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
            accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
            accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
        else:
            transformer = accelerator.prepare(transformer)

        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

        transformer.train()

        if args.full_fp16:
            # patch accelerator for fp16 training
            # def patch_accelerator_for_fp16_training(accelerator):
            org_unscale_grads = accelerator.scaler._unscale_grads_

            def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
                return org_unscale_grads(optimizer, inv_scale, found_inf, True)

            accelerator.scaler._unscale_grads_ = _unscale_grads_replacer

        # resume from local or huggingface. accelerator.step is set
        self.resume_from_local_or_hf_if_specified(accelerator, args)  # accelerator.load_state(args.resume)

        # calculate the number of epochs
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

        # start training
        # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
        accelerator.print("running training / 学習開始")
        accelerator.print(f"  num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
        accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
        accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
        accelerator.print(
            f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
        )
        # accelerator.print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
        accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
        accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

        if accelerator.is_main_process:
            init_kwargs = {}
            if args.wandb_run_name:
                init_kwargs["wandb"] = {"name": args.wandb_run_name}
            if args.log_tracker_config is not None:
                init_kwargs = toml.load(args.log_tracker_config)
            accelerator.init_trackers(
                "hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
                config=train_utils.get_sanitized_config_or_none(args),
                init_kwargs=init_kwargs,
            )

        # TODO skip until initial step
        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")

        epoch_to_start = 0
        global_step = 0
        noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")

        loss_recorder = train_utils.LossRecorder()
        del train_dataset_group

        # functions for saving/removing models
        def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
            os.makedirs(args.output_dir, exist_ok=True)
            ckpt_file = os.path.join(args.output_dir, ckpt_name)

            accelerator.print(f"\nsaving checkpoint: {ckpt_file}")

            title = args.metadata_title if args.metadata_title is not None else args.output_name
            if args.min_timestep is not None or args.max_timestep is not None:
                min_time_step = args.min_timestep if args.min_timestep is not None else 0
                max_time_step = args.max_timestep if args.max_timestep is not None else 1000
                md_timesteps = (min_time_step, max_time_step)
            else:
                md_timesteps = None

            sai_metadata = sai_model_spec.build_metadata(
                None,
                time.time(),
                title,
                None,
                args.metadata_author,
                args.metadata_description,
                args.metadata_license,
                args.metadata_tags,
                timesteps=md_timesteps,
                is_lora=False,
            )

            save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
            if args.huggingface_repo_id is not None:
                huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

        def remove_model(old_ckpt_name):
            old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
            if os.path.exists(old_ckpt_file):
                accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                os.remove(old_ckpt_file)

        # For --sample_at_first
        optimizer_eval_fn()
        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
        optimizer_train_fn()
        if len(accelerator.trackers) > 0:
            # log an empty object to commit the sample images to wandb
            accelerator.log({}, step=0)

        # training loop

        # log device and dtype for each model
        logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")

        clean_memory_on_device(accelerator.device)

        pos_embed_cache = {}

        for epoch in range(epoch_to_start, num_train_epochs):
            accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
            current_epoch.value = epoch + 1

            for step, batch in enumerate(train_dataloader):
                latents, llm_embeds, llm_mask, clip_embeds = batch
                bsz = latents.shape[0]
                current_step.value = global_step

                with accelerator.accumulate(transformer):
                    latents = latents * vae_module.SCALING_FACTOR

                    # Sample noise that we'll add to the latents
                    noise = torch.randn_like(latents)

                    # calculate model input and timesteps
                    noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
                        args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
                    )

                    weighting = compute_loss_weighting_for_sd3(
                        args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
                    )

                    # ensure guidance_scale in args is float
                    guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)  # , dtype=dit_dtype)

                    # ensure the hidden state will require grad
                    if args.gradient_checkpointing:
                        noisy_model_input.requires_grad_(True)
                        guidance_vec.requires_grad_(True)

                    pos_emb_shape = latents.shape[1:]
                    if pos_emb_shape not in pos_embed_cache:
                        freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(accelerator.unwrap_model(transformer), latents.shape[2:])
                        # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
                        # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
                        pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
                    else:
                        freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]

                    # call DiT
                    latents = latents.to(device=accelerator.device, dtype=dit_dtype)
                    noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
                    # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
                    # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
                    # llm_mask = llm_mask.to(device=accelerator.device)
                    # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
                    with accelerator.autocast():
                        model_pred = transformer(
                            noisy_model_input,
                            timesteps,
                            text_states=llm_embeds,
                            text_mask=llm_mask,
                            text_states_2=clip_embeds,
                            freqs_cos=freqs_cos,
                            freqs_sin=freqs_sin,
                            guidance=guidance_vec,
                            return_dict=False,
                        )

                    # flow matching loss
                    target = noise - latents

                    loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")

                    if weighting is not None:
                        loss = loss * weighting
                    # loss = loss.mean([1, 2, 3])

                    # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
                    # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)

                    loss = loss.mean()  # mean over all elements, so no need to divide by batch_size

                    accelerator.backward(loss)
                    if accelerator.sync_gradients:
                        # self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                        state = accelerate.PartialState()
                        if state.distributed_type != accelerate.DistributedType.NO:
                            for param in transformer.parameters():
                                if param.grad is not None:
                                    param.grad = accelerator.reduce(param.grad, reduction="mean")

                        if args.max_grad_norm != 0.0:
                            params_to_clip = accelerator.unwrap_model(transformer).parameters()
                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    optimizer_eval_fn()
                    self.sample_images(
                        accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
                    )

                    # save the model at every specified number of steps
                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                        accelerator.wait_for_everyone()
                        if accelerator.is_main_process:
                            ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
                            save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)

                            if args.save_state:
                                train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)

                            remove_step_no = train_utils.get_remove_step_no(args, global_step)
                            if remove_step_no is not None:
                                remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
                                remove_model(remove_ckpt_name)
                    optimizer_train_fn()

                current_loss = loss.detach().item()
                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                avr_loss: float = loss_recorder.moving_average
                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)

                if len(accelerator.trackers) > 0:
                    logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
                    accelerator.log(logs, step=global_step)

                if global_step >= args.max_train_steps:
                    break

            if len(accelerator.trackers) > 0:
                logs = {"loss/epoch": loss_recorder.moving_average}
                accelerator.log(logs, step=epoch + 1)

            accelerator.wait_for_everyone()

            # save the model at every specified number of epochs
            optimizer_eval_fn()
            if args.save_every_n_epochs is not None:
                saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
                if is_main_process and saving:
                    ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
                    save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)

                    remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
                    if remove_epoch_no is not None:
                        remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
                        remove_model(remove_ckpt_name)

                    if args.save_state:
                        train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
            optimizer_train_fn()

        # end of epoch

        if is_main_process:
            transformer = accelerator.unwrap_model(transformer)

        accelerator.end_training()
        optimizer_eval_fn()

        if args.save_state or args.save_state_on_train_end:
            train_utils.save_state_on_train_end(args, accelerator)

        if is_main_process:
            ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
            save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)

            logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    def int_or_float(value):
        if value.endswith("%"):
            try:
                return float(value[:-1]) / 100.0
            except ValueError:
                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
        try:
            float_value = float(value)
            if float_value >= 1 and float_value.is_integer():
                return int(value)
            return float(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
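    # illustrative conversions by int_or_float: "600" -> 600 (int steps),
    # "0.1" -> 0.1 (ratio of train steps), "10%" -> 0.1 (percentage as ratio)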

    parser = argparse.ArgumentParser()

    # general settings
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
        help="using .toml instead of args to pass hyperparameters / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
    )
    parser.add_argument(
        "--dataset_config",
        type=pathlib.Path,
        default=None,
        required=True,
        help="config file for dataset / データセットの設定ファイル",
    )

    # training settings
    parser.add_argument(
        "--sdpa",
        action="store_true",
        help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
    )
    parser.add_argument(
        "--flash_attn",
        action="store_true",
        help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
    )
    parser.add_argument(
        "--sage_attn",
        action="store_true",
        help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
    )
    parser.add_argument(
        "--xformers",
        action="store_true",
        help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
    )
    parser.add_argument(
        "--split_attn",
        action="store_true",
        help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
        " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
    )
    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    parser.add_argument(
        "--max_train_epochs",
        type=int,
        default=None,
        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
    )
    parser.add_argument(
        "--max_data_loader_n_workers",
        type=int,
        default=8,
        help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
    )
    parser.add_argument(
        "--persistent_data_loader_workers",
        action="store_true",
        help="persistent DataLoader workers (useful for reducing the time gap between epochs, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
    )
    parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
    parser.add_argument(
        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help="use mixed precision / 混合精度を使う場合、その精度",
    )
    parser.add_argument(
        "--trainable_modules",
        nargs="+",
        default=".",
        help="list of module name substrings to train; parameters whose names contain any of them are unfrozen",
    )

    parser.add_argument(
        "--logging_dir",
        type=str,
        default=None,
        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
    )
    parser.add_argument(
        "--log_with",
        type=str,
        default=None,
        choices=["tensorboard", "wandb", "all"],
        help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
    )
    parser.add_argument(
        "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
    )
    parser.add_argument(
        "--log_tracker_name",
        type=str,
        default=None,
        help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
    )
    parser.add_argument(
        "--log_tracker_config",
        type=str,
        default=None,
        help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
    )
    parser.add_argument(
        "--wandb_api_key",
        type=str,
        default=None,
        help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
    )
    parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")

    parser.add_argument(
        "--ddp_timeout",
        type=int,
        default=None,
        help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
    )
    parser.add_argument(
        "--ddp_gradient_as_bucket_view",
        action="store_true",
        help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
    )
    parser.add_argument(
        "--ddp_static_graph",
        action="store_true",
        help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
    )

    parser.add_argument(
        "--sample_every_n_steps",
        type=int,
        default=None,
        help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
    )
    parser.add_argument(
        "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
    )
    parser.add_argument(
        "--sample_every_n_epochs",
        type=int,
        default=None,
        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
    )
    parser.add_argument(
        "--sample_prompts",
        type=str,
        default=None,
        help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
    )

    # optimizer and lr scheduler settings
    parser.add_argument(
        "--optimizer_type",
        type=str,
        default="",
        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
        "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc.",
    )
    parser.add_argument(
        "--optimizer_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
    )
    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int_or_float,
        default=0,
        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
        " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
    )
    parser.add_argument(
        "--lr_decay_steps",
        type=int_or_float,
        default=0,
        help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
        " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
    )
    parser.add_argument(
        "--lr_scheduler_num_cycles",
        type=int,
        default=1,
        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
    )
    parser.add_argument(
        "--lr_scheduler_power",
        type=float,
        default=1,
        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
    )
    parser.add_argument(
        "--lr_scheduler_timescale",
        type=int,
        default=None,
        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`"
        + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
    )
    parser.add_argument(
        "--lr_scheduler_min_lr_ratio",
        type=float,
        default=None,
        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
        + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラで有効",
    )
    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
    parser.add_argument(
        "--lr_scheduler_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
    )
# model settings | |
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス") | |
parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16") | |
parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス") | |
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") | |
parser.add_argument( | |
"--vae_tiling", | |
action="store_true", | |
help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled." | |
" / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。", | |
) | |
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") | |
parser.add_argument( | |
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" | |
) | |
parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ") | |
parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ") | |
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16") | |
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う") | |
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") | |
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する") | |
parser.add_argument( | |
"--blocks_to_swap", | |
type=int, | |
default=None, | |
help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX", | |
) | |
parser.add_argument( | |
"--img_in_txt_in_offloading", | |
action="store_true", | |
help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする", | |
) | |
# parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers") | |
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.") | |
parser.add_argument( | |
"--timestep_sampling", | |
choices=["sigma", "uniform", "sigmoid", "shift"], | |
default="sigma", | |
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." | |
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", | |
) | |
parser.add_argument( | |
"--discrete_flow_shift", | |
type=float, | |
default=1.0, | |
help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。", | |
) | |
parser.add_argument( | |
"--sigmoid_scale", | |
type=float, | |
default=1.0, | |
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。', | |
) | |
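# Hedged sketch (not necessarily this script's exact code) of how "sigmoid" and "shift" sampling are | |
# commonly implemented, with t in [0, 1] and shift taken from --discrete_flow_shift: | |
#   t = torch.sigmoid(args.sigmoid_scale * torch.randn(batch_size))  # "sigmoid" | |
#   t = shift * t / (1 + (shift - 1) * t)                            # "shift": pushes samples toward higher t | |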
parser.add_argument( | |
"--weighting_scheme", | |
type=str, | |
default="none", | |
choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"], | |
help="weighting scheme for timestep distribution. Default is none" | |
" / タイムステップ分布の重み付けスキーム、デフォルトはnone", | |
) | |
parser.add_argument( | |
"--logit_mean", | |
type=float, | |
default=0.0, | |
help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", | |
) | |
parser.add_argument( | |
"--logit_std", | |
type=float, | |
default=1.0, | |
help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", | |
) | |
parser.add_argument( | |
"--mode_scale", | |
type=float, | |
default=1.29, | |
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", | |
) | |
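# For reference, a hedged sketch of the "logit_normal" density, mirroring diffusers' | |
# compute_density_for_timestep_sampling (an assumption for illustration, not this script's exact code): | |
#   u = torch.sigmoid(torch.normal(mean=args.logit_mean, std=args.logit_std, size=(batch_size,))) | |
# u in (0, 1) is then mapped onto the discrete timestep grid. | |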
parser.add_argument( | |
"--min_timestep", | |
type=int, | |
default=None, | |
help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", | |
) | |
parser.add_argument( | |
"--max_timestep", | |
type=int, | |
default=None, | |
help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", | |
) | |
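# Illustration only: --min_timestep 500 --max_timestep 1000 would restrict training to the noisier half of the schedule. | |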
# save and load settings | |
parser.add_argument( | |
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" | |
) | |
parser.add_argument( | |
"--output_name", | |
type=str, | |
required=True, | |
help="base name of trained model file, without extension / 学習後のモデルの拡張子を除くファイル名", | |
) | |
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") | |
parser.add_argument( | |
"--save_every_n_epochs", | |
type=int, | |
default=None, | |
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", | |
) | |
parser.add_argument( | |
"--save_every_n_steps", | |
type=int, | |
default=None, | |
help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", | |
) | |
parser.add_argument( | |
"--save_last_n_epochs", | |
type=int, | |
default=None, | |
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", | |
) | |
parser.add_argument( | |
"--save_last_n_epochs_state", | |
type=int, | |
default=None, | |
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", | |
) | |
parser.add_argument( | |
"--save_last_n_steps", | |
type=int, | |
default=None, | |
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", | |
) | |
parser.add_argument( | |
"--save_last_n_steps_state", | |
type=int, | |
default=None, | |
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", | |
) | |
parser.add_argument( | |
"--save_state", | |
action="store_true", | |
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", | |
) | |
parser.add_argument( | |
"--save_state_on_train_end", | |
action="store_true", | |
help="save training state (including optimizer states etc.) on train end even if --save_state is not specified" | |
" / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する", | |
) | |
# SAI Model spec | |
parser.add_argument( | |
"--metadata_title", | |
type=str, | |
default=None, | |
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", | |
) | |
parser.add_argument( | |
"--metadata_author", | |
type=str, | |
default=None, | |
help="author name for model metadata / メタデータに書き込まれるモデル作者名", | |
) | |
parser.add_argument( | |
"--metadata_description", | |
type=str, | |
default=None, | |
help="description for model metadata / メタデータに書き込まれるモデル説明", | |
) | |
parser.add_argument( | |
"--metadata_license", | |
type=str, | |
default=None, | |
help="license for model metadata / メタデータに書き込まれるモデルライセンス", | |
) | |
parser.add_argument( | |
"--metadata_tags", | |
type=str, | |
default=None, | |
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", | |
) | |
# huggingface settings | |
parser.add_argument( | |
"--huggingface_repo_id", | |
type=str, | |
default=None, | |
help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", | |
) | |
parser.add_argument( | |
"--huggingface_repo_type", | |
type=str, | |
default=None, | |
help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", | |
) | |
parser.add_argument( | |
"--huggingface_path_in_repo", | |
type=str, | |
default=None, | |
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", | |
) | |
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") | |
parser.add_argument( | |
"--huggingface_repo_visibility", | |
type=str, | |
default=None, | |
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", | |
) | |
parser.add_argument( | |
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" | |
) | |
parser.add_argument( | |
"--resume_from_huggingface", | |
action="store_true", | |
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", | |
) | |
parser.add_argument( | |
"--async_upload", | |
action="store_true", | |
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", | |
) | |
return parser | |
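# read_config_from_file below flattens all TOML sections into one namespace, so section names carry no meaning. | |
# A hypothetical config (file name and values are made up for illustration): | |
# | |
#   # config.toml | |
#   [model] | |
#   dit = "/path/to/dit.safetensors" | |
#   [training] | |
#   learning_rate = 1e-4 | |
# | |
# is read as the flat set {dit=..., learning_rate=1e-4}; each key must match an argparse option defined above. | |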
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): | |
if not args.config_file: | |
return args | |
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file | |
if not os.path.exists(config_path): | |
logger.info(f"{config_path} not found.") | |
exit(1) | |
logger.info(f"Loading settings from {config_path}...") | |
with open(config_path, "r", encoding="utf-8") as f: | |
config_dict = toml.load(f) | |
# combine all sections into one | |
ignore_nesting_dict = {} | |
for section_name, section_dict in config_dict.items(): | |
# if value is not dict, save key and value as is | |
if not isinstance(section_dict, dict): | |
ignore_nesting_dict[section_name] = section_dict | |
continue | |
# if value is dict, save all key and value into one dict | |
for key, value in section_dict.items(): | |
ignore_nesting_dict[key] = value | |
config_args = argparse.Namespace(**ignore_nesting_dict) | |
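# parse_args fills this pre-populated namespace, so values passed on the command line override those from | |
# the config file, while config values act as defaults for everything else. | |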
args = parser.parse_args(namespace=config_args) | |
args.config_file = os.path.splitext(args.config_file)[0] | |
logger.info(f"Loaded settings from config file: {args.config_file}") | |
return args | |
if __name__ == "__main__": | |
parser = setup_parser() | |
args = parser.parse_args() | |
args = read_config_from_file(args, parser) | |
trainer = FineTuningTrainer() | |
trainer.train(args) | |
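# A hypothetical invocation (script name and paths are placeholders, not from this repository): | |
#   accelerate launch hv_train.py --dit /path/to/dit.safetensors --output_dir ./output --output_name my-model --config_file config.toml | |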