import math
import re
from copy import deepcopy
from pathlib import Path
from typing import Dict

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.distributed import Sampler

import utils
from utils.hparams import hparams

# ==========LR schedulers==========

class RSQRTSchedule(object):
    """Linear warmup followed by inverse-square-root decay, driven manually via step(num_updates)."""

    def __init__(self, optimizer):
        super().__init__()
        self.optimizer = optimizer
        self.constant_lr = hparams['lr']
        self.warmup_updates = hparams['warmup_updates']
        self.hidden_size = hparams['hidden_size']
        self.lr = hparams['lr']
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        warmup = min(num_updates / self.warmup_updates, 1.0)
        rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']
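

# A minimal usage sketch of RSQRTSchedule (illustration only; nothing in this
# module calls it). It assumes `hparams` has already been populated with 'lr',
# 'warmup_updates' and 'hidden_size'; the linear layer is just a placeholder model.
def _example_rsqrt_schedule_usage():
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=hparams['lr'])
    scheduler = RSQRTSchedule(optimizer)
    for num_updates in range(1000):
        # ... forward, backward and optimizer.step() would happen here ...
        scheduler.step(num_updates)  # linear warmup, then inverse-sqrt decay, floored at 1e-7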


class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup and then cosine decay.
        Linearly increases the learning-rate multiplier from 0 to 1 over `warmup_steps` training steps.
        With the default `cycles=0.5`, the multiplier then decreases from 1. towards 0. over the remaining
        `t_total - warmup_steps` steps following half a cosine period; other values of `cycles` complete
        that many cosine cycles after warmup.
        `eta_min` (default=0.0) is the lower bound applied to the multiplier, i.e. the minimum fraction
        of the base learning rate the scheduler will reach.
    """

    def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.eta_min = eta_min
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return step / max(1.0, self.warmup_steps)
        # progress after warmup
        progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps)
        return max(self.eta_min, 0.5 * (1. + math.cos(math.pi * self.cycles * 2.0 * progress)))
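

# A minimal usage sketch of WarmupCosineSchedule (illustration only). The step
# counts and learning rate below are arbitrary example values: the multiplier
# ramps linearly from 0 to 1 over the first 1000 steps, then follows half a
# cosine period from 1 towards 0 (clamped at eta_min) over the remaining steps.
def _example_warmup_cosine_usage():
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # 1e-3 is the peak lr reached after warmup
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=1000, t_total=100000, eta_min=0.01)
    for _ in range(100000):
        # ... forward, backward and optimizer.step() would happen here ...
        scheduler.step()  # LambdaLR-style: advance one step, no argument needed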


# ==========Torch samplers==========

class DsBatchSampler(Sampler):
    def __init__(self, dataset, max_batch_frames, max_batch_size, sub_indices=None,
                 num_replicas=None, rank=None,
                 required_batch_count_multiple=1, batch_by_size=True, sort_by_similar_size=True,
                 size_reversed=False, shuffle_sample=False, shuffle_batch=False,
                 disallow_empty_batch=True, pad_batch_assignment=True, seed=0, drop_last=False) -> None:
        # Fall back to single-process values when distributed information is not given.
        if num_replicas is None:
            num_replicas = 1
        if rank is None:
            rank = 0
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.max_batch_frames = max_batch_frames
        self.max_batch_size = max_batch_size
        self.sub_indices = sub_indices
        self.num_replicas = num_replicas
        self.rank = rank
        self.required_batch_count_multiple = required_batch_count_multiple
        self.batch_by_size = batch_by_size
        self.sort_by_similar_size = sort_by_similar_size
        self.size_reversed = size_reversed
        self.shuffle_sample = shuffle_sample
        self.shuffle_batch = shuffle_batch
        self.disallow_empty_batch = disallow_empty_batch
        self.pad_batch_assignment = pad_batch_assignment
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0
        self.batches = None
        self.formed = None

    def __form_batches(self):
        # Batches are cached per (epoch, seed) combination; skip if already formed.
        if self.formed == self.epoch + self.seed:
            return
        # Seed the RNG with seed + epoch so that every rank draws the same permutations.
        rng = np.random.default_rng(self.seed + self.epoch)

        # Create indices
        if self.shuffle_sample:
            if self.sub_indices is not None:
                rng.shuffle(self.sub_indices)
                indices = np.array(self.sub_indices)
            else:
                indices = rng.permutation(len(self.dataset))
            if self.sort_by_similar_size:
                grid = int(hparams['sampler_frame_count_grid'])
                assert grid > 0
                sizes = (np.round(np.array(self.dataset.sizes)[indices] / grid) * grid).clip(grid, None)
                sizes *= (-1 if self.size_reversed else 1)
                indices = indices[np.argsort(sizes, kind='mergesort')]
            indices = indices.tolist()
        else:
            indices = self.sub_indices if self.sub_indices is not None else list(range(len(self.dataset)))

        # Batching
        if self.batch_by_size:
            batches = utils.batch_by_size(
                indices, self.dataset.num_frames,
                max_batch_frames=self.max_batch_frames,
                max_batch_size=self.max_batch_size
            )
        else:
            batches = [indices[i:i + self.max_batch_size] for i in range(0, len(indices), self.max_batch_size)]

        if len(batches) < self.num_replicas and self.disallow_empty_batch:
            raise RuntimeError("There are not enough batches to assign one to each node.")

        # Either drop_last or separate the leftovers.
        floored_total_batch_count = (len(batches) // self.num_replicas) * self.num_replicas
        if self.drop_last and len(batches) > floored_total_batch_count:
            batches = batches[:floored_total_batch_count]
        leftovers = []
        if len(batches) == 0:
            raise RuntimeError("There are no batches left after dropping the last batch.")
        elif self.shuffle_batch:
            leftovers = (rng.permutation(len(batches) - floored_total_batch_count)
                         + floored_total_batch_count).tolist()
        else:
            leftovers = list(range(floored_total_batch_count, len(batches)))

        # Initial batch assignment to current rank.
        batch_assignment = np.arange(floored_total_batch_count).reshape(-1, self.num_replicas).transpose()
        if self.shuffle_batch:
            batch_assignment = rng.permuted(batch_assignment, axis=0)[self.rank].tolist()
        else:
            batch_assignment = batch_assignment[self.rank].tolist()

        # Assign leftovers or pad the batch assignment.
        floored_batch_count = len(batch_assignment)
        if self.rank < len(leftovers):
            batch_assignment.append(leftovers[self.rank])
            floored_batch_count += 1
        elif len(leftovers) > 0 and self.pad_batch_assignment:
            if not batch_assignment:
                raise RuntimeError("Cannot pad an empty batch assignment.")
            batch_assignment.append(batch_assignment[self.epoch % floored_batch_count])

        # Ensure the batch count is a multiple of required_batch_count_multiple.
        if self.required_batch_count_multiple > 1 and len(batch_assignment) % self.required_batch_count_multiple != 0:
            ceiled_batch_count = math.ceil(
                len(batch_assignment) / self.required_batch_count_multiple
            ) * self.required_batch_count_multiple
            for i in range(ceiled_batch_count - len(batch_assignment)):
                batch_assignment.append(
                    batch_assignment[(i + self.epoch * self.required_batch_count_multiple) % floored_batch_count])

        if batch_assignment:
            self.batches = [deepcopy(batches[i]) for i in batch_assignment]
        else:
            self.batches = [[]]
        self.formed = self.epoch + self.seed

        del indices
        del batches
        del batch_assignment

    def __iter__(self):
        self.__form_batches()
        return iter(self.batches)

    def __len__(self):
        self.__form_batches()
        if self.batches is None:
            raise RuntimeError("Batches are not initialized. Call __form_batches first.")
        return len(self.batches)

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.__form_batches()
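

# A minimal usage sketch of DsBatchSampler (illustration only). It assumes the
# dataset exposes `sizes`, `num_frames` and a `collater` collate function as the
# sampler and loader expect, and that hparams contains 'sampler_frame_count_grid';
# the frame/size limits and seed are arbitrary example values.
def _example_batch_sampler_usage(dataset, rank=0, world_size=1):
    sampler = DsBatchSampler(
        dataset,
        max_batch_frames=80000, max_batch_size=48,
        num_replicas=world_size, rank=rank,
        sort_by_similar_size=True, shuffle_sample=True, shuffle_batch=True, seed=1234
    )
    loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # batches are re-formed deterministically from seed + epoch
        for _batch in loader:
            pass  # training step would go here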


# ==========PL related==========

class DsModelCheckpoint(ModelCheckpoint):
    def __init__(
            self,
            *args,
            permanent_ckpt_start,
            permanent_ckpt_interval,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.permanent_ckpt_start = permanent_ckpt_start or 0
        self.permanent_ckpt_interval = permanent_ckpt_interval or 0
        self.enable_permanent_ckpt = self.permanent_ckpt_start > 0 and self.permanent_ckpt_interval > 0
        self._verbose = self.verbose
        self.verbose = False

    def state_dict(self):
        ret = super().state_dict()
        ret.pop('dirpath')
        return ret

    def load_state_dict(self, state_dict) -> None:
        super().load_state_dict(state_dict)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.lightning_module.skip_immediate_ckpt_save:
            trainer.lightning_module.skip_immediate_ckpt_save = False
            return
        self.last_val_step = trainer.global_step
        super().on_validation_end(trainer, pl_module)

    def _update_best_and_save(
            self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, torch.Tensor]
    ) -> None:
        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

        del_filepath = None
        _op = max if self.mode == "min" else min
        while len(self.best_k_models) > k and k > 0:
            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

            del_filepath = self.kth_best_model_path
            self.best_k_models.pop(del_filepath)
            filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)
            if del_filepath is not None and filepath != del_filepath:
                self._remove_checkpoint(trainer, del_filepath)

        if len(self.best_k_models) == k and k > 0:
            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

        super()._update_best_and_save(current, trainer, monitor_candidates)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        filepath = (Path(self.dirpath) / Path(filepath).name).resolve()
        super()._save_checkpoint(trainer, str(filepath))
        if self._verbose:
            relative_path = filepath
            # Avoid using `is_relative_to` because Python 3.8 does not support this
            if Path('.').resolve() in filepath.parents:
                relative_path = filepath.relative_to(Path('.').resolve())
            rank_zero_info(f'Checkpoint {relative_path} saved.')

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str):
        filepath = (Path(self.dirpath) / Path(filepath).name).resolve()
        relative_path = filepath
        # Avoid using `is_relative_to` because Python 3.8 does not support this
        if Path('.').resolve() in filepath.parents:
            relative_path = filepath.relative_to(Path('.').resolve())
        search = re.search(r'steps_\d+', relative_path.stem)
        if search:
            step = int(search.group(0)[6:])
            if self.enable_permanent_ckpt and \
                    step >= self.permanent_ckpt_start and \
                    (step - self.permanent_ckpt_start) % self.permanent_ckpt_interval == 0:
                rank_zero_info(f'Checkpoint {relative_path} is now permanent.')
                return
        super()._remove_checkpoint(trainer, filepath)
        if self._verbose:
            rank_zero_info(f'Removed checkpoint {relative_path}.')
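

# A minimal construction sketch for DsModelCheckpoint (illustration only); the
# argument values are arbitrary examples, not defaults. With these settings the
# callback keeps the 5 best checkpoints, and any checkpoint whose step is
# >= permanent_ckpt_start and lands on a permanent_ckpt_interval boundary is
# never deleted by the rotation above.
def _example_checkpoint_callback(work_dir):
    return DsModelCheckpoint(
        dirpath=work_dir,
        filename='model_ckpt_steps_{step}',
        monitor='val_loss', mode='min',
        save_top_k=5,
        every_n_train_steps=2000,
        permanent_ckpt_start=120000,
        permanent_ckpt_interval=40000
    )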


def get_latest_checkpoint_path(work_dir):
    if not isinstance(work_dir, Path):
        work_dir = Path(work_dir)
    if not work_dir.exists():
        return None

    last_step = -1
    last_ckpt_name = None

    for ckpt in work_dir.glob('model_ckpt_steps_*.ckpt'):
        search = re.search(r'steps_\d+', ckpt.name)
        if search:
            step = int(search.group(0)[6:])
            if step > last_step:
                last_step = step
                last_ckpt_name = str(ckpt)

    return last_ckpt_name if last_ckpt_name is not None else None
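

# Illustration only: resume from the newest checkpoint found in the work
# directory (the glob above matches names like 'model_ckpt_steps_2000.ckpt').
def _example_resume_from_latest(trainer, model, work_dir):
    ckpt_path = get_latest_checkpoint_path(work_dir)  # None if no matching checkpoint exists
    trainer.fit(model, ckpt_path=ckpt_path)  # ckpt_path=None starts training from scratch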


class DsTQDMProgressBar(TQDMProgressBar):
    def __init__(self, refresh_rate: int = 1, process_position: int = 0, show_steps: bool = True):
        super().__init__(refresh_rate, process_position)
        self.show_steps = show_steps

    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        if 'batch_size' in items:
            items['batch_size'] = int(items['batch_size'])
        if self.show_steps:
            items['steps'] = str(trainer.global_step)
        for k, v in items.items():
            if isinstance(v, float):
                if np.isnan(v):
                    items[k] = 'nan'
                elif 0.001 <= v < 10:
                    items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-')
                elif 0.00001 <= v < 0.001:
                    if len(np.format_float_positional(v, unique=True, precision=8, trim='-')) > 8:
                        items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-')
                    else:
                        items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-')
                elif v < 0.00001:
                    items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-')
        items.pop("v_num", None)
        return items


class DsTensorBoardLogger(TensorBoardLogger):
    @property
    def all_rank_experiment(self):
        # Unlike `experiment`, this property returns a writable SummaryWriter on every rank.
        if rank_zero_only.rank == 0:
            return self.experiment
        if hasattr(self, "_all_rank_experiment") and self._all_rank_experiment is not None:
            return self._all_rank_experiment
        assert rank_zero_only.rank != 0
        if self.root_dir:
            self._fs.makedirs(self.root_dir, exist_ok=True)
        if _TENSORBOARD_AVAILABLE:
            from torch.utils.tensorboard import SummaryWriter
        else:
            from tensorboardX import SummaryWriter  # type: ignore[no-redef]
        self._all_rank_experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
        return self._all_rank_experiment

    def finalize(self, status: str) -> None:
        if rank_zero_only.rank == 0:
            super().finalize(status)
        elif hasattr(self, "_all_rank_experiment") and self._all_rank_experiment is not None:
            self.all_rank_experiment.flush()
            self.all_rank_experiment.close()

    def __getstate__(self):
        state = super().__getstate__()
        if "_all_rank_experiment" in state:
            del state["_all_rank_experiment"]
        return state
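

# A minimal sketch (illustration only) of what all_rank_experiment is for: unlike
# `experiment`, it yields a usable SummaryWriter on every rank, so per-rank values
# can be logged during distributed training. The tag and value are placeholders.
def _example_all_rank_logging(logger: DsTensorBoardLogger, global_step: int):
    writer = logger.all_rank_experiment
    writer.add_scalar(f'rank_{rank_zero_only.rank}/example_metric', 0.0, global_step)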


def get_strategy(
        devices="auto",
        num_nodes=1,
        accelerator="auto",
        strategy={"name": "auto"},
        precision=None,
):
    from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
    from lightning.pytorch.accelerators import AcceleratorRegistry
    from lightning.pytorch.accelerators.cuda import CUDAAccelerator
    from lightning.pytorch.accelerators.mps import MPSAccelerator
    from lightning.pytorch.strategies import Strategy, SingleDeviceStrategy, StrategyRegistry
    from lightning.pytorch.trainer.connectors import accelerator_connector
    from lightning.pytorch.utilities.rank_zero import rank_zero_warn

    class _DsAcceleratorConnector(accelerator_connector._AcceleratorConnector):
        def __init__(self) -> None:
            accelerator_connector._register_external_accelerators_and_strategies()
            self._registered_strategies = StrategyRegistry.available_strategies()
            self._accelerator_types = AcceleratorRegistry.available_accelerators()
            self._parallel_devices = []

            self._check_config_and_set_final_flags(
                strategy=strategy["name"],
                accelerator=accelerator,
                precision=precision,
                plugins=[],
                sync_batchnorm=False,
            )

            if self._accelerator_flag == "auto":
                self._accelerator_flag = self._choose_auto_accelerator()
            elif self._accelerator_flag == "gpu":
                self._accelerator_flag = self._choose_gpu_accelerator_backend()

            self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
            self._set_parallel_devices_and_init_accelerator()

            if self._strategy_flag == "auto":
                self._strategy_flag = self._choose_strategy()
            self._check_strategy_and_fallback()
            self._init_strategy()

            for k in ["colossalai", "bagua", "hpu", "hpu_parallel", "hpu_single", "ipu", "ipu_strategy"]:
                if k in StrategyRegistry:
                    StrategyRegistry.remove(k)

        def _init_strategy(self) -> None:
            assert isinstance(self._strategy_flag, (str, Strategy))
            if isinstance(self._strategy_flag, str):
                if self._strategy_flag not in StrategyRegistry:
                    available_names = ", ".join(sorted(StrategyRegistry.available_strategies())) or "none"
                    raise KeyError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}")
                data = StrategyRegistry[self._strategy_flag]
                params = {}
                # Replicate additional logic for _choose_strategy when dealing with single device strategies
                if issubclass(data["strategy"], SingleDeviceStrategy):
                    if self._accelerator_flag == "hpu":
                        params = {"device": torch.device("hpu")}
                    elif self._accelerator_flag == "tpu":
                        params = {"device": self._parallel_devices[0]}
                    elif data["strategy"] is SingleDeviceStrategy:
                        if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
                        ):
                            params = {"device": _determine_root_gpu_device(self._parallel_devices)}
                        else:
                            params = {"device": "cpu"}
                    else:
                        raise NotImplementedError
                params.update(data["init_params"])
                params.update({k: v for k, v in strategy.items() if k != "name"})
                self.strategy = data["strategy"](**utils.filter_kwargs(params, data["strategy"]))
            elif isinstance(self._strategy_flag, SingleDeviceStrategy):
                params = {"device": self._strategy_flag.root_device}
                params.update({k: v for k, v in strategy.items() if k != "name"})
                self.strategy = self._strategy_flag.__class__(**utils.filter_kwargs(params, self._strategy_flag.__class__))
            else:
                rank_zero_warn(
                    f"Inferred strategy {self._strategy_flag.__class__.__name__} cannot take custom configurations. "
                    f"To use custom configurations, please specify the strategy name explicitly."
                )
                self.strategy = self._strategy_flag

    return _DsAcceleratorConnector().strategy
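

# A minimal usage sketch of get_strategy (illustration only). The extra keys in
# the strategy dict besides 'name' are forwarded to the strategy class, so a dict
# like {'name': 'ddp', 'find_unused_parameters': False} would configure
# DDPStrategy; the values below are placeholders.
def _example_build_trainer():
    strategy = get_strategy(
        devices='auto',
        num_nodes=1,
        accelerator='auto',
        strategy={'name': 'auto'},
        precision='16-mixed'
    )
    return pl.Trainer(accelerator='auto', devices='auto', num_nodes=1, strategy=strategy)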