import torch import pytorch_lightning as pl from torch.optim.lr_scheduler import ReduceLROnPlateau from collections.abc import MutableMapping # from speechbrain.processing.speech_augmentation import SpeedPerturb def flatten_dict(d, parent_key="", sep="_"): """Flattens a dictionary into a single-level dictionary while preserving parent keys. Taken from `SO `_ Args: d (MutableMapping): Dictionary to be flattened. parent_key (str): String to use as a prefix to all subsequent keys. sep (str): String to use as a separator between two key levels. Returns: dict: Single-level dictionary, flattened. """ items = [] for k, v in d.items(): new_key = parent_key + sep + k if parent_key else k if isinstance(v, MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items) class AudioLightningModuleMultiDecoder(pl.LightningModule): def __init__( self, audio_model=None, video_model=None, optimizer=None, loss_func=None, train_loader=None, val_loader=None, test_loader=None, scheduler=None, config=None, ): super().__init__() self.audio_model = audio_model self.video_model = video_model self.optimizer = optimizer self.loss_func = loss_func self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader self.scheduler = scheduler self.config = {} if config is None else config # Speed Aug # self.speedperturb = SpeedPerturb( # self.config["datamodule"]["data_config"]["sample_rate"], # speeds=[95, 100, 105], # perturb_prob=1.0 # ) # Save lightning"s AttributeDict under self.hparams self.default_monitor = "val_loss/dataloader_idx_0" self.save_hyperparameters(self.config_to_hparams(self.config)) # self.print(self.audio_model) self.validation_step_outputs = [] self.test_step_outputs = [] def forward(self, wav, mouth=None): """Applies forward pass of the model. Returns: :class:`torch.Tensor` """ if self.training: return self.audio_model(wav) else: ests, num_blocks = self.audio_model(wav) assert ests.shape[0] % num_blocks == 0 batch_size = int(ests.shape[0] / num_blocks) return ests[-batch_size:] def training_step(self, batch, batch_nb): mixtures, targets, _ = batch new_targets = [] min_len = -1 # if self.config["training"]["SpeedAug"] == True: # with torch.no_grad(): # for i in range(targets.shape[1]): # new_target = self.speedperturb(targets[:, i, :]) # new_targets.append(new_target) # if i == 0: # min_len = new_target.shape[-1] # else: # if new_target.shape[-1] < min_len: # min_len = new_target.shape[-1] # targets = torch.zeros( # targets.shape[0], # targets.shape[1], # min_len, # device=targets.device, # dtype=torch.float, # ) # for i, new_target in enumerate(new_targets): # targets[:, i, :] = new_targets[i][:, 0:min_len] # mixtures = targets.sum(1) # print(mixtures.shape) est_sources, num_blocks = self(mixtures) assert est_sources.shape[0] % num_blocks == 0 batch_size = int(est_sources.shape[0] / num_blocks) ests_sources_each_block = [] for i in range(num_blocks): ests_sources_each_block.append(est_sources[i * batch_size : (i + 1) * batch_size]) loss = self.loss_func["train"](ests_sources_each_block, targets) self.log( "train_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True, logger=True, ) return {"loss": loss} def validation_step(self, batch, batch_nb, dataloader_idx): # cal val loss if dataloader_idx == 0: mixtures, targets, _ = batch # print(mixtures.shape) est_sources = self(mixtures) loss = self.loss_func["val"](est_sources, targets) self.log( "val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True, logger=True, ) self.validation_step_outputs.append(loss) return {"val_loss": loss} # cal test loss if (self.trainer.current_epoch) % 10 == 0 and dataloader_idx == 1: mixtures, targets, _ = batch # print(mixtures.shape) est_sources = self(mixtures) tloss = self.loss_func["val"](est_sources, targets) self.log( "test_loss", tloss, on_epoch=True, prog_bar=True, sync_dist=True, logger=True, ) self.test_step_outputs.append(tloss) return {"test_loss": tloss} def on_validation_epoch_end(self): # val avg_loss = torch.stack(self.validation_step_outputs).mean() val_loss = torch.mean(self.all_gather(avg_loss)) self.log( "lr", self.optimizer.param_groups[0]["lr"], on_epoch=True, prog_bar=True, sync_dist=True, ) self.logger.experiment.log( {"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch} ) self.logger.experiment.log( {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch} ) # test if (self.trainer.current_epoch) % 10 == 0: avg_loss = torch.stack(self.test_step_outputs).mean() test_loss = torch.mean(self.all_gather(avg_loss)) self.logger.experiment.log( {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch} ) self.validation_step_outputs.clear() # free memory self.test_step_outputs.clear() # free memory def configure_optimizers(self): """Initialize optimizers, batch-wise and epoch-wise schedulers.""" if self.scheduler is None: return self.optimizer if not isinstance(self.scheduler, (list, tuple)): self.scheduler = [self.scheduler] # support multiple schedulers epoch_schedulers = [] for sched in self.scheduler: if not isinstance(sched, dict): if isinstance(sched, ReduceLROnPlateau): sched = {"scheduler": sched, "monitor": self.default_monitor} epoch_schedulers.append(sched) else: sched.setdefault("monitor", self.default_monitor) sched.setdefault("frequency", 1) # Backward compat if sched["interval"] == "batch": sched["interval"] = "step" assert sched["interval"] in [ "epoch", "step", ], "Scheduler interval should be either step or epoch" epoch_schedulers.append(sched) return [self.optimizer], epoch_schedulers # def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # if metric is None: # scheduler.step() # else: # scheduler.step(metric) def train_dataloader(self): """Training dataloader""" return self.train_loader def val_dataloader(self): """Validation dataloader""" return [self.val_loader, self.test_loader] def on_save_checkpoint(self, checkpoint): """Overwrite if you want to save more things in the checkpoint.""" checkpoint["training_config"] = self.config return checkpoint @staticmethod def config_to_hparams(dic): """Sanitizes the config dict to be handled correctly by torch SummaryWriter. It flatten the config dict, converts ``None`` to ``"None"`` and any list and tuple into torch.Tensors. Args: dic (dict): Dictionary to be transformed. Returns: dict: Transformed dictionary. """ dic = flatten_dict(dic) for k, v in dic.items(): if v is None: dic[k] = str(v) elif isinstance(v, (list, tuple)): dic[k] = torch.tensor(v) return dic