import logging
import os
import pathlib
import shutil
import sys
from typing import Dict, Type

import matplotlib

import utils
from utils.text_encoder import TokenTextEncoder

matplotlib.use('Agg')  # select a non-interactive backend before pyplot gets imported anywhere

import torch.utils.data
from torchmetrics import Metric, MeanMetric
import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only

from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import (
    DsModelCheckpoint, DsTQDMProgressBar,
    DsBatchSampler, DsTensorBoardLogger,
    get_latest_checkpoint_path, get_strategy
)
from utils.phoneme_utils import locate_dictionary, build_phoneme_list

torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')


class BaseTask(pl.LightningModule):
    """
    Base class for training tasks.
    1. *training_step*: run the model and log the losses;
    2. *on_validation_epoch_end*: aggregate and log the validation losses and metrics;
    3. *start*: build the trainer, copy payload files to the work dir and start training;
    4. *on_save_checkpoint* and *on_load_checkpoint*: save and restore extra training states.

    Subclasses should define:
    1. *_build_model*: how to build the model;
    2. *build_losses_and_metrics*: which validation losses and metrics to register;
    3. *run_model*: one forward pass of the model, returning the losses when not inferring;
    4. *_validation_step*: one validation step of the model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_batch_frames = hparams['max_batch_frames']
        self.max_batch_size = hparams['max_batch_size']
        self.max_val_batch_frames = hparams['max_val_batch_frames']
        if self.max_val_batch_frames == -1:
            hparams['max_val_batch_frames'] = self.max_val_batch_frames = self.max_batch_frames
        self.max_val_batch_size = hparams['max_val_batch_size']
        if self.max_val_batch_size == -1:
            hparams['max_val_batch_size'] = self.max_val_batch_size = self.max_batch_size

        self.training_sampler = None
        self.skip_immediate_validation = False
        self.skip_immediate_ckpt_save = False

        self.phone_encoder = self.build_phone_encoder()
        self.build_model()
        self.valid_losses: Dict[str, Metric] = {}
        self.valid_metrics: Dict[str, Metric] = {}

    def _finish_init(self):
        self.register_validation_loss('total_loss')
        self.build_losses_and_metrics()
        assert len(self.valid_losses) > 0, "No validation loss registered. Please check your configuration file."
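    # A hedged illustration of the batching-related keys that __init__ reads from
    # ``hparams``; the values below are placeholders, not recommendations:
    #
    #     max_batch_frames: 80000       # frame budget per training batch
    #     max_batch_size: 48            # max sentences per training batch
    #     max_val_batch_frames: -1      # -1 inherits max_batch_frames
    #     max_val_batch_size: -1        # -1 inherits max_batch_size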
    ###########
    # Training, validation and testing
    ###########

    def setup(self, stage):
        # `dataset_cls` is expected to be assigned by the subclass before training starts
        self.train_dataset = self.dataset_cls('train')
        self.valid_dataset = self.dataset_cls('valid')
        self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1)

    def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
        key_list = []
        for prefix in hparams['frozen_params']:
            for key in model_state_dict:
                if key.startswith(prefix):
                    key_list.append(key)
        return list(set(key_list))

    def freeze_params(self) -> None:
        model_state_dict = self.state_dict().keys()
        freeze_keys = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
        for key in freeze_keys:
            params = self.get_parameter(key)
            params.requires_grad = False

    def unfreeze_all_params(self) -> None:
        for param in self.model.parameters():
            param.requires_grad = True

    def load_finetune_ckpt(self, state_dict) -> None:
        strict_shapes = hparams['finetune_strict_shapes']
        if not strict_shapes:
            # Drop pretrained tensors whose shapes do not match the current model
            cur_model_state_dict = self.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print('| Unmatched keys: ', key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        self.load_state_dict(state_dict, strict=False)

    def load_pre_train_model(self):
        pre_train_ckpt_path = hparams['finetune_ckpt_path']
        blacklist = hparams['finetune_ignored_params']
        if blacklist is None:
            blacklist = []
        if pre_train_ckpt_path is None:
            raise RuntimeError("finetune_ckpt_path is not specified.")
        ckpt = torch.load(pre_train_ckpt_path)
        if isinstance(self.model, CategorizedModule):
            self.model.check_category(ckpt.get('category'))
        state_dict = {}
        for key in ckpt['state_dict']:
            # Skip parameters whose names start with any blacklisted prefix
            skip = False
            for prefix in blacklist:
                if key.startswith(prefix):
                    skip = True
                    break
            if skip:
                continue
            state_dict[key] = ckpt['state_dict'][key]
            print(f'| Loaded param from pretrained checkpoint: {key}')
        return state_dict

    @staticmethod
    def build_phone_encoder():
        phone_list = build_phoneme_list()
        return TokenTextEncoder(vocab_list=phone_list)

    def _build_model(self):
        raise NotImplementedError()

    def build_model(self):
        self.model = self._build_model()
        self.unfreeze_all_params()
        if hparams['freezing_enabled']:
            self.freeze_params()
        if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
            self.load_finetune_ckpt(self.load_pre_train_model())
        self.print_arch()

    @rank_zero_only
    def print_arch(self):
        utils.print_arch(self.model)

    def build_losses_and_metrics(self):
        raise NotImplementedError()

    def register_validation_metric(self, name: str, metric: Metric):
        assert isinstance(metric, Metric)
        self.valid_metrics[name] = metric

    def register_validation_loss(self, name: str, Aggregator: Type[Metric] = MeanMetric):
        assert issubclass(Aggregator, Metric)
        self.valid_losses[name] = Aggregator()
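    # A hedged sketch of the freezing/finetuning configuration consumed above;
    # the prefixes and paths are illustrative, not from a real config:
    #
    #     freezing_enabled: true
    #     frozen_params: [model.fs2.encoder]        # freeze by state-dict prefix
    #     finetune_enabled: true
    #     finetune_ckpt_path: checkpoints/pretrained/model_ckpt_steps_160000.ckpt
    #     finetune_ignored_params: [model.fs2.spk_embed]
    #     finetune_strict_shapes: false             # drop shape-mismatched tensors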
    def run_model(self, sample, infer=False):
        """
        steps:
        1. run the full model;
        2. calculate losses if not infer.
        """
        raise NotImplementedError()

    def on_train_epoch_start(self):
        if self.training_sampler is not None:
            self.training_sampler.set_epoch(self.current_epoch)

    def _training_step(self, sample):
        """
        :return: total_loss: torch.Tensor, log_outputs: dict
        """
        losses = self.run_model(sample)
        total_loss = sum(losses.values())
        return total_loss, {**losses, 'batch_size': float(sample['size'])}

    def training_step(self, sample, batch_idx):
        total_loss, log_outputs = self._training_step(sample)

        # logs to progress bar
        self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
        # logs to tensorboard
        if self.global_step % hparams['log_interval'] == 0:
            tb_log = {f'training/{k}': v for k, v in log_outputs.items()}
            tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0]
            self.logger.log_metrics(tb_log, step=self.global_step)

        return total_loss

    def _on_validation_start(self):
        pass

    def on_validation_start(self):
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        self._on_validation_start()
        for metric in self.valid_losses.values():
            metric.to(self.device)
            metric.reset()
        for metric in self.valid_metrics.values():
            metric.to(self.device)
            metric.reset()

    def _validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        :return: losses: dict, weight: int
        """
        raise NotImplementedError()

    def validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        """
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        if sample['size'] > 0:
            with torch.autocast(self.device.type, enabled=False):
                losses, weight = self._validation_step(sample, batch_idx)
            losses = {
                'total_loss': sum(losses.values()),
                **losses
            }
            for k, v in losses.items():
                self.valid_losses[k].update(v, weight=weight)

    def _on_validation_epoch_end(self):
        pass

    def on_validation_epoch_end(self):
        if self.skip_immediate_validation:
            self.skip_immediate_validation = False
            self.skip_immediate_ckpt_save = True
            return
        self._on_validation_epoch_end()
        loss_vals = {k: v.compute() for k, v in self.valid_losses.items()}
        metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()}
        self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
        self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step)
        self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step)

    # noinspection PyMethodMayBeStatic
    def build_scheduler(self, optimizer):
        from utils import build_lr_scheduler_from_config

        scheduler_args = hparams['lr_scheduler_args']
        assert scheduler_args['scheduler_cls'] != ''
        scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
        return scheduler

    # noinspection PyMethodMayBeStatic
    def build_optimizer(self, model):
        from utils import build_object_from_class_name

        optimizer_args = hparams['optimizer_args']
        assert optimizer_args['optimizer_cls'] != ''
        # `beta1`/`beta2` are accepted as a convenience and merged into `betas`
        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
        optimizer = build_object_from_class_name(
            optimizer_args['optimizer_cls'],
            torch.optim.Optimizer,
            model.parameters(),
            **optimizer_args
        )
        return optimizer
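    # A hedged example of the optimizer/scheduler configuration consumed by the
    # two builders above; class paths and values are illustrative only:
    #
    #     optimizer_args:
    #       optimizer_cls: torch.optim.AdamW
    #       lr: 0.0004
    #       beta1: 0.9
    #       beta2: 0.98
    #       weight_decay: 0.0
    #     lr_scheduler_args:
    #       scheduler_cls: torch.optim.lr_scheduler.StepLR
    #       step_size: 50000
    #       gamma: 0.5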
    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        scheduler = self.build_scheduler(optm)
        if scheduler is None:
            return optm
        return {
            "optimizer": optm,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def train_dataloader(self):
        self.training_sampler = DsBatchSampler(
            self.train_dataset,
            max_batch_frames=self.max_batch_frames,
            max_batch_size=self.max_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            sort_by_similar_size=hparams['sort_by_len'],
            size_reversed=True,
            required_batch_count_multiple=hparams['accumulate_grad_batches'],
            shuffle_sample=True,
            shuffle_batch=True
        )
        return torch.utils.data.DataLoader(
            self.train_dataset,
            collate_fn=self.train_dataset.collater,
            batch_sampler=self.training_sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            pin_memory=True,
            persistent_workers=True
        )

    def val_dataloader(self):
        sampler = DsBatchSampler(
            self.valid_dataset,
            max_batch_frames=self.max_val_batch_frames,
            max_batch_size=self.max_val_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            shuffle_sample=False,
            shuffle_batch=False,
            disallow_empty_batch=False,
            pad_batch_assignment=False
        )
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            collate_fn=self.valid_dataset.collater,
            batch_sampler=sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            persistent_workers=True
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def on_test_start(self):
        self.on_validation_start()

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def on_test_end(self):
        return self.on_validation_end()
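    # A minimal sketch of how a subclass wires the hooks above together.
    # `MyTask`, `MyDataset`, `MyModel` and the 'mel' loss name are assumptions
    # for illustration, not names from this codebase:
    #
    #     class MyTask(BaseTask):
    #         def __init__(self, *args, **kwargs):
    #             super().__init__(*args, **kwargs)
    #             self.dataset_cls = MyDataset
    #             self._finish_init()
    #
    #         def _build_model(self):
    #             return MyModel(num_phonemes=self.phone_encoder.vocab_size)
    #
    #         def build_losses_and_metrics(self):
    #             self.register_validation_loss('mel')
    #
    #         def run_model(self, sample, infer=False):
    #             output = self.model(sample['tokens'])
    #             if infer:
    #                 return output
    #             return {'mel': torch.nn.functional.l1_loss(output, sample['mel'])}
    #
    #         def _validation_step(self, sample, batch_idx):
    #             return self.run_model(sample), sample['size']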
    ###########
    # Running configuration
    ###########

    @classmethod
    def start(cls):
        task = cls()
        work_dir = pathlib.Path(hparams['work_dir'])
        trainer = pl.Trainer(
            accelerator=hparams['pl_trainer_accelerator'],
            devices=hparams['pl_trainer_devices'],
            num_nodes=hparams['pl_trainer_num_nodes'],
            strategy=get_strategy(
                hparams['pl_trainer_devices'],
                hparams['pl_trainer_num_nodes'],
                hparams['pl_trainer_accelerator'],
                hparams['pl_trainer_strategy'],
                hparams['pl_trainer_precision'],
            ),
            precision=hparams['pl_trainer_precision'],
            callbacks=[
                DsModelCheckpoint(
                    dirpath=work_dir,
                    filename='model_ckpt_steps_{step}',
                    auto_insert_metric_name=False,
                    monitor='step',
                    mode='max',
                    save_last=False,
                    save_top_k=hparams['num_ckpt_keep'],
                    permanent_ckpt_start=hparams['permanent_ckpt_start'],
                    permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
                    verbose=True
                ),
                DsTQDMProgressBar(),
            ],
            logger=DsTensorBoardLogger(
                save_dir=str(work_dir),
                name='lightning_logs',
                version='latest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            # `val_check_interval` is configured in global steps, so scale it by the
            # gradient accumulation factor to match Lightning's batch-based counting
            val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'],
            check_val_every_n_epoch=None,
            log_every_n_steps=1,
            max_steps=hparams['max_updates'],
            use_distributed_sampler=False,
            num_sanity_val_steps=hparams['num_sanity_val_steps'],
            accumulate_grad_batches=hparams['accumulate_grad_batches']
        )
        if not hparams['infer']:  # train
            @rank_zero_only
            def train_payload_copy():
                # Copy spk_map.json and dictionary.txt to the work dir
                binary_dir = pathlib.Path(hparams['binary_data_dir'])
                spk_map = work_dir / 'spk_map.json'
                spk_map_src = binary_dir / 'spk_map.json'
                if not spk_map.exists() and spk_map_src.exists():
                    shutil.copy(spk_map_src, spk_map)
                    print(f'| Copied spk map to {spk_map}.')
                dictionary = work_dir / 'dictionary.txt'
                dict_src = binary_dir / 'dictionary.txt'
                if not dictionary.exists():
                    if dict_src.exists():
                        shutil.copy(dict_src, dictionary)
                    else:
                        shutil.copy(locate_dictionary(), dictionary)
                    print(f'| Copied dictionary to {dictionary}.')

            train_payload_copy()
            trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir))
        else:
            trainer.test(task)

    def on_save_checkpoint(self, checkpoint):
        if isinstance(self.model, CategorizedModule):
            checkpoint['category'] = self.model.category
        checkpoint['trainer_stage'] = self.trainer.state.stage.value

    def on_load_checkpoint(self, checkpoint):
        from lightning.pytorch.trainer.states import RunningStage
        from utils import simulate_lr_scheduler

        if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
            self.skip_immediate_validation = True

        optimizer_args = hparams['optimizer_args']
        scheduler_args = hparams['lr_scheduler_args']

        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

        if checkpoint.get('optimizer_states', None):
            opt_states = checkpoint['optimizer_states']
            assert len(opt_states) == 1  # only support one optimizer
            opt_state = opt_states[0]
            for param_group in opt_state['param_groups']:
                for k, v in optimizer_args.items():
                    if k in param_group and param_group[k] != v:
                        if 'lr_schedulers' in checkpoint and checkpoint['lr_schedulers'] and k == 'lr':
                            continue
                        rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
                        param_group[k] = v
                if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
                    rank_zero_info(
                        f'| Overriding optimizer parameter initial_lr from checkpoint: '
                        f'{param_group["initial_lr"]} -> {optimizer_args["lr"]}'
                    )
                    param_group['initial_lr'] = optimizer_args['lr']

        if checkpoint.get('lr_schedulers', None):
            assert checkpoint.get('optimizer_states', False)
            assert len(checkpoint['lr_schedulers']) == 1  # only support one scheduler
            # Re-simulate the scheduler from the current config instead of trusting
            # the checkpointed state, then propagate the resulting lr to the optimizer
            checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
                optimizer_args, scheduler_args,
                step_count=checkpoint['global_step'],
                num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups'])
            )
            for param_group, new_lr in zip(
                checkpoint['optimizer_states'][0]['param_groups'],
                checkpoint['lr_schedulers'][0]['_last_lr'],
            ):
                if param_group['lr'] != new_lr:
                    rank_zero_info(
                        f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}'
                    )
                    param_group['lr'] = new_lr
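# A hedged usage sketch: a training entry point for a hypothetical subclass.
# `MyTask` and the config-loading call are assumptions; the real entry scripts
# in this codebase populate `hparams` before any task is constructed:
#
#     if __name__ == '__main__':
#         from utils.hparams import set_hparams  # assumed config loader
#         set_hparams()
#         MyTask.start()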