import logging
import os
import pathlib
import shutil
import sys
from typing import Dict, Type

import matplotlib

import utils
from utils.text_encoder import TokenTextEncoder

matplotlib.use('Agg')

import torch.utils.data
from torchmetrics import Metric, MeanMetric
import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only
from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import (
    DsModelCheckpoint, DsTQDMProgressBar,
    DsBatchSampler, DsTensorBoardLogger,
    get_latest_checkpoint_path, get_strategy
)
from utils.phoneme_utils import locate_dictionary, build_phoneme_list

torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')

class BaseTask(pl.LightningModule):
    """
    Base class for training tasks.
    1. *load_ckpt*:
        load a checkpoint;
    2. *training_step*:
        record and log the loss;
    3. *optimizer_step*:
        run the backward step;
    4. *start*:
        load training configs, back up code, log to tensorboard, start training;
    5. *configure_ddp* and *init_ddp_connection*:
        start parallel training.

    Subclasses should define:
    1. *build_model*, *build_optimizer*, *build_scheduler*:
        how to build the model, the optimizer and the training scheduler;
    2. *_training_step*:
        one training step of the model;
    3. *on_validation_end* and *_on_validation_end*:
        postprocess the validation output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_batch_frames = hparams['max_batch_frames']
        self.max_batch_size = hparams['max_batch_size']
        self.max_val_batch_frames = hparams['max_val_batch_frames']
        if self.max_val_batch_frames == -1:
            hparams['max_val_batch_frames'] = self.max_val_batch_frames = self.max_batch_frames
        self.max_val_batch_size = hparams['max_val_batch_size']
        if self.max_val_batch_size == -1:
            hparams['max_val_batch_size'] = self.max_val_batch_size = self.max_batch_size

        self.training_sampler = None
        self.skip_immediate_validation = False
        self.skip_immediate_ckpt_save = False

        self.phone_encoder = self.build_phone_encoder()
        self.build_model()
        self.valid_losses: Dict[str, Metric] = {}
        self.valid_metrics: Dict[str, Metric] = {}

    def _finish_init(self):
        self.register_validation_loss('total_loss')
        self.build_losses_and_metrics()
        assert len(self.valid_losses) > 0, "No validation loss registered. Please check your configuration file."

    ###########
    # Training, validation and testing
    ###########

    def setup(self, stage):
        self.train_dataset = self.dataset_cls('train')
        self.valid_dataset = self.dataset_cls('valid')
        self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1)

    def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
        key_list = []
        for i in hparams['frozen_params']:
            for j in model_state_dict:
                if j.startswith(i):
                    key_list.append(j)
        return list(set(key_list))

    def freeze_params(self) -> None:
        model_state_dict = self.state_dict().keys()
        freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
        for i in freeze_key:
            params = self.get_parameter(i)
            params.requires_grad = False

    def unfreeze_all_params(self) -> None:
        for i in self.model.parameters():
            i.requires_grad = True

    def load_finetune_ckpt(self, state_dict) -> None:
        adapt_shapes = hparams['finetune_strict_shapes']
        if not adapt_shapes:
            cur_model_state_dict = self.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print('| Unmatched keys: ', key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        self.load_state_dict(state_dict, strict=False)

    def load_pre_train_model(self):
        pre_train_ckpt_path = hparams['finetune_ckpt_path']
        blacklist = hparams['finetune_ignored_params']
        # whitelist = hparams['pre_train_whitelist']
        if blacklist is None:
            blacklist = []
        # if whitelist is None:
        #     raise RuntimeError("")
        if pre_train_ckpt_path is not None:
            ckpt = torch.load(pre_train_ckpt_path)
            # if ckpt.get('category') is None:
            #     raise RuntimeError("")
            if isinstance(self.model, CategorizedModule):
                self.model.check_category(ckpt.get('category'))
            state_dict = {}
            for i in ckpt['state_dict']:
                # if 'diffusion' in i:
                #     if i in rrrr:
                #         continue
                skip = False
                for b in blacklist:
                    if i.startswith(b):
                        skip = True
                        break
                if skip:
                    continue
                state_dict[i] = ckpt['state_dict'][i]
                print(i)
            return state_dict
        else:
            raise RuntimeError("finetune_ckpt_path is not specified.")

    def build_phone_encoder(self):
        phone_list = build_phoneme_list()
        return TokenTextEncoder(vocab_list=phone_list)

    def _build_model(self):
        raise NotImplementedError()

    def build_model(self):
        self.model = self._build_model()
        # utils.load_warp(self)
        self.unfreeze_all_params()
        if hparams['freezing_enabled']:
            self.freeze_params()
        if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
            self.load_finetune_ckpt(self.load_pre_train_model())
        self.print_arch()

    def print_arch(self):
        utils.print_arch(self.model)

    def build_losses_and_metrics(self):
        raise NotImplementedError()

    def register_validation_metric(self, name: str, metric: Metric):
        assert isinstance(metric, Metric)
        self.valid_metrics[name] = metric

    def register_validation_loss(self, name: str, Aggregator: Type[Metric] = MeanMetric):
        assert issubclass(Aggregator, Metric)
        self.valid_losses[name] = Aggregator()

    def run_model(self, sample, infer=False):
        """
        Steps:
        1. run the full model;
        2. calculate losses if not infer.
        """
        raise NotImplementedError()
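
    # Illustrative sketch (assumption, not part of the original code): a typical
    # override returns the raw model output when infer=True and a dict of named
    # loss tensors otherwise. The sample keys 'tokens' and 'target' are placeholders.
    #
    #     def run_model(self, sample, infer=False):
    #         output = self.model(sample['tokens'])
    #         if infer:
    #             return output
    #         return {'l1_loss': torch.nn.functional.l1_loss(output, sample['target'])}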

    def on_train_epoch_start(self):
        if self.training_sampler is not None:
            self.training_sampler.set_epoch(self.current_epoch)

    def _training_step(self, sample):
        """
        :return: total_loss: torch.Tensor, log_outputs: dict
        """
        losses = self.run_model(sample)
        total_loss = sum(losses.values())
        return total_loss, {**losses, 'batch_size': float(sample['size'])}

    def training_step(self, sample, batch_idx):
        total_loss, log_outputs = self._training_step(sample)

        # logs to progress bar
        self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
        # logs to tensorboard
        if self.global_step % hparams['log_interval'] == 0:
            tb_log = {f'training/{k}': v for k, v in log_outputs.items()}
            tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0]
            self.logger.log_metrics(tb_log, step=self.global_step)

        return total_loss

    # def on_before_optimizer_step(self, *args, **kwargs):
    #     self.log_dict(grad_norm(self, norm_type=2))

    def _on_validation_start(self):
        pass

    def on_validation_start(self):
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        self._on_validation_start()
        for metric in self.valid_losses.values():
            metric.to(self.device)
            metric.reset()
        for metric in self.valid_metrics.values():
            metric.to(self.device)
            metric.reset()

    def _validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        :return: loss_log: dict, weight: int
        """
        raise NotImplementedError()
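
    # Illustrative sketch (assumption, not part of the original code): subclasses
    # commonly reuse run_model() and weight the aggregated losses by the batch size:
    #
    #     def _validation_step(self, sample, batch_idx):
    #         losses = self.run_model(sample, infer=False)
    #         return losses, sample['size']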

    def validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        """
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        if sample['size'] > 0:
            with torch.autocast(self.device.type, enabled=False):
                losses, weight = self._validation_step(sample, batch_idx)
            losses = {
                'total_loss': sum(losses.values()),
                **losses
            }
            for k, v in losses.items():
                self.valid_losses[k].update(v, weight=weight)

    def _on_validation_epoch_end(self):
        pass

    def on_validation_epoch_end(self):
        if self.skip_immediate_validation:
            self.skip_immediate_validation = False
            self.skip_immediate_ckpt_save = True
            return
        self._on_validation_epoch_end()
        loss_vals = {k: v.compute() for k, v in self.valid_losses.items()}
        metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()}
        self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
        self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step)
        self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step)

    # noinspection PyMethodMayBeStatic
    def build_scheduler(self, optimizer):
        from utils import build_lr_scheduler_from_config

        scheduler_args = hparams['lr_scheduler_args']
        assert scheduler_args['scheduler_cls'] != ''
        scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
        return scheduler

    # noinspection PyMethodMayBeStatic
    def build_optimizer(self, model):
        from utils import build_object_from_class_name

        optimizer_args = hparams['optimizer_args']
        assert optimizer_args['optimizer_cls'] != ''
        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
        optimizer = build_object_from_class_name(
            optimizer_args['optimizer_cls'],
            torch.optim.Optimizer,
            model.parameters(),
            **optimizer_args
        )
        return optimizer

    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        scheduler = self.build_scheduler(optm)
        if scheduler is None:
            return optm
        return {
            "optimizer": optm,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def train_dataloader(self):
        self.training_sampler = DsBatchSampler(
            self.train_dataset,
            max_batch_frames=self.max_batch_frames,
            max_batch_size=self.max_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            sort_by_similar_size=hparams['sort_by_len'],
            size_reversed=True,
            required_batch_count_multiple=hparams['accumulate_grad_batches'],
            shuffle_sample=True,
            shuffle_batch=True
        )
        return torch.utils.data.DataLoader(
            self.train_dataset,
            collate_fn=self.train_dataset.collater,
            batch_sampler=self.training_sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            pin_memory=True,
            persistent_workers=True
        )

    def val_dataloader(self):
        sampler = DsBatchSampler(
            self.valid_dataset,
            max_batch_frames=self.max_val_batch_frames,
            max_batch_size=self.max_val_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            shuffle_sample=False,
            shuffle_batch=False,
            disallow_empty_batch=False,
            pad_batch_assignment=False
        )
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            collate_fn=self.valid_dataset.collater,
            batch_sampler=sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            persistent_workers=True
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def on_test_start(self):
        self.on_validation_start()

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def on_test_end(self):
        return self.on_validation_end()

    ###########
    # Running configuration
    ###########

    @classmethod
    def start(cls):
        task = cls()
        # if pre_train is not None:
        #     task.load_state_dict(pre_train, strict=False)
        #     print("load success-------------------------------------------------------------------")
        work_dir = pathlib.Path(hparams['work_dir'])
        trainer = pl.Trainer(
            accelerator=hparams['pl_trainer_accelerator'],
            devices=hparams['pl_trainer_devices'],
            num_nodes=hparams['pl_trainer_num_nodes'],
            strategy=get_strategy(
                hparams['pl_trainer_devices'],
                hparams['pl_trainer_num_nodes'],
                hparams['pl_trainer_accelerator'],
                hparams['pl_trainer_strategy'],
                hparams['pl_trainer_precision'],
            ),
            precision=hparams['pl_trainer_precision'],
            callbacks=[
                DsModelCheckpoint(
                    dirpath=work_dir,
                    filename='model_ckpt_steps_{step}',
                    auto_insert_metric_name=False,
                    monitor='step',
                    mode='max',
                    save_last=False,
                    # every_n_train_steps=hparams['val_check_interval'],
                    save_top_k=hparams['num_ckpt_keep'],
                    permanent_ckpt_start=hparams['permanent_ckpt_start'],
                    permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
                    verbose=True
                ),
                # LearningRateMonitor(logging_interval='step'),
                DsTQDMProgressBar(),
            ],
            logger=DsTensorBoardLogger(
                save_dir=str(work_dir),
                name='lightning_logs',
                version='latest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'],
            # so this is global_steps
            check_val_every_n_epoch=None,
            log_every_n_steps=1,
            max_steps=hparams['max_updates'],
            use_distributed_sampler=False,
            num_sanity_val_steps=hparams['num_sanity_val_steps'],
            accumulate_grad_batches=hparams['accumulate_grad_batches']
        )
        if not hparams['infer']:  # train
            def train_payload_copy():
                # Copy spk_map.json and dictionary.txt to work dir
                binary_dir = pathlib.Path(hparams['binary_data_dir'])
                spk_map = work_dir / 'spk_map.json'
                spk_map_src = binary_dir / 'spk_map.json'
                if not spk_map.exists() and spk_map_src.exists():
                    shutil.copy(spk_map_src, spk_map)
                    print(f'| Copied spk map to {spk_map}.')
                dictionary = work_dir / 'dictionary.txt'
                dict_src = binary_dir / 'dictionary.txt'
                if not dictionary.exists():
                    if dict_src.exists():
                        shutil.copy(dict_src, dictionary)
                    else:
                        shutil.copy(locate_dictionary(), dictionary)
                    print(f'| Copied dictionary to {dictionary}.')

            train_payload_copy()
            trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir))
        else:
            trainer.test(task)

    def on_save_checkpoint(self, checkpoint):
        if isinstance(self.model, CategorizedModule):
            checkpoint['category'] = self.model.category
        checkpoint['trainer_stage'] = self.trainer.state.stage.value

    def on_load_checkpoint(self, checkpoint):
        from lightning.pytorch.trainer.states import RunningStage
        from utils import simulate_lr_scheduler

        if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
            self.skip_immediate_validation = True

        optimizer_args = hparams['optimizer_args']
        scheduler_args = hparams['lr_scheduler_args']
        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

        if checkpoint.get('optimizer_states', None):
            opt_states = checkpoint['optimizer_states']
            assert len(opt_states) == 1  # only support one optimizer
            opt_state = opt_states[0]
            for param_group in opt_state['param_groups']:
                for k, v in optimizer_args.items():
                    if k in param_group and param_group[k] != v:
                        if 'lr_schedulers' in checkpoint and checkpoint['lr_schedulers'] and k == 'lr':
                            continue
                        rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
                        param_group[k] = v
                if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
                    rank_zero_info(
                        f'| Overriding optimizer parameter initial_lr from checkpoint: '
                        f'{param_group["initial_lr"]} -> {optimizer_args["lr"]}'
                    )
                    param_group['initial_lr'] = optimizer_args['lr']

        if checkpoint.get('lr_schedulers', None):
            assert checkpoint.get('optimizer_states', False)
            assert len(checkpoint['lr_schedulers']) == 1  # only support one scheduler
            checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
                optimizer_args, scheduler_args,
                step_count=checkpoint['global_step'],
                num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups'])
            )
            for param_group, new_lr in zip(
                checkpoint['optimizer_states'][0]['param_groups'],
                checkpoint['lr_schedulers'][0]['_last_lr'],
            ):
                if param_group['lr'] != new_lr:
                    rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
                    param_group['lr'] = new_lr