###
# Author: Kai Li
# Date: 2024-01-22 01:16:22
# Email: lk21@mails.tsinghua.edu.cn
# LastEditTime: 2024-01-24 00:05:10
###
import json
from typing import List
import os
from omegaconf import OmegaConf
import argparse
import pytorch_lightning as pl
import torch
torch.set_float32_matmul_precision("highest")
import hydra
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
# from pytorch_lightning.loggers import Logger
from omegaconf import DictConfig
import look2hear.system
import look2hear.datas
import look2hear.losses
from look2hear.utils import RankedLogger, instantiate, print_only
import warnings
warnings.filterwarnings("ignore")


def train(cfg: DictConfig) -> None:
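    """Instantiate every training component from the OmegaConf config and run fit.

    Each config section used below (datas, model, discriminator, optimizer_g/d,
    scheduler_g/d, loss_g/d, metrics, system, logger, trainer) is expected to
    carry a ``_target_`` key for ``hydra.utils.instantiate``.
    """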
if cfg.get("seed"):
pl.seed_everything(cfg.seed, workers=True)
# instantiate datamodule
print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)
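    # hydra.utils.instantiate builds the object named by a config's _target_ key
    # and passes the remaining keys as constructor arguments. A hypothetical
    # datamodule entry (field names are illustrative, not from the real configs):
    #
    #   datas:
    #     _target_: look2hear.datas.SomeDataModule
    #     batch_size: 4
    #     num_workers: 8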
    # instantiate model
    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
    print_only(f"Instantiating Discriminator <{cfg.discriminator._target_}>")
    discriminator: torch.nn.Module = hydra.utils.instantiate(cfg.discriminator)
    # instantiate optimizer
    print_only(f"Instantiating optimizer <{cfg.optimizer_g._target_}>")
    optimizer_g: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_g, params=model.parameters())
    optimizer_d: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_d, params=discriminator.parameters())
    # optimizer: torch.optim = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)
    # instantiate scheduler
    print_only(f"Instantiating scheduler <{cfg.scheduler_g._target_}>")
    scheduler_g: torch.optim.lr_scheduler._LRScheduler = hydra.utils.instantiate(cfg.scheduler_g, optimizer=optimizer_g)
    scheduler_d: torch.optim.lr_scheduler._LRScheduler = hydra.utils.instantiate(cfg.scheduler_d, optimizer=optimizer_d)
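    # instantiate() merges the keyword arguments given here (params / optimizer)
    # with the keys in the config, so the YAML only has to carry hyperparameters.
    # An illustrative (assumed) fragment:
    #
    #   optimizer_g:
    #     _target_: torch.optim.AdamW
    #     lr: 1.0e-3
    #   scheduler_g:
    #     _target_: torch.optim.lr_scheduler.StepLR
    #     step_size: 2
    #     gamma: 0.98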
    # instantiate loss
    print_only(f"Instantiating loss <{cfg.loss_g._target_}>")
    loss_g: torch.nn.Module = hydra.utils.instantiate(cfg.loss_g)
    loss_d: torch.nn.Module = hydra.utils.instantiate(cfg.loss_d)
    losses = {
        "g": loss_g,
        "d": loss_d,
    }
    # instantiate metrics
    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)
    # instantiate system
    print_only(f"Instantiating system <{cfg.system._target_}>")
    system: LightningModule = hydra.utils.instantiate(
        cfg.system,
        model=model,
        discriminator=discriminator,
        loss_func=losses,
        metrics=metrics,
        optimizer=[optimizer_g, optimizer_d],
        scheduler=[scheduler_g, scheduler_d],
    )
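    # The system (a LightningModule) receives the paired generator/discriminator
    # optimizers, schedulers, and losses; it is presumably expected to alternate
    # adversarial updates between the two (e.g., via Lightning's manual
    # optimization), though that logic lives in look2hear.system.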
    # instantiate callbacks
    callbacks: List[Callback] = []
    if cfg.get("early_stopping"):
        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
    if cfg.get("checkpoint"):
        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
        checkpoint: pl.callbacks.ModelCheckpoint = hydra.utils.instantiate(cfg.checkpoint)
        callbacks.append(checkpoint)
    # instantiate logger
    print_only(f"Instantiating logger <{cfg.logger._target_}>")
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
    logger = hydra.utils.instantiate(cfg.logger)
    # instantiate trainer
    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        strategy=DDPStrategy(find_unused_parameters=True),
    )
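    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient on a given step (typical when generator and discriminator are
    # updated in alternation), at the cost of extra per-step bookkeeping.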
    trainer.fit(system, datamodule=datamodule)
    print_only("Training finished!")
    # Export the best checkpoints; guard against the case where no
    # ModelCheckpoint callback was configured (checkpoint would be undefined).
    if cfg.get("checkpoint"):
        best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
        with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
            json.dump(best_k, f, indent=0)
        state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
        system.load_state_dict(state_dict=state_dict["state_dict"])
        system.cpu()
        to_save = system.audio_model.serialize()
        torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))
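    # Assumption: audio_model.serialize() packages the weights together with
    # whatever is needed to rebuild the model, so best_model.pth can be loaded
    # for inference without the Lightning wrapper.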
    import wandb

    if wandb.run:
        print_only("Closing wandb!")
        wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conf_dir",
        default="local/conf.yml",
        help="Full path to the training configuration file",
    )
    args = parser.parse_args()
    cfg = OmegaConf.load(args.conf_dir)
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
    # Save a copy of the config to the experiment directory
    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))
    train(cfg)
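
# Example invocation (script name assumed; --conf_dir defaults to local/conf.yml):
#   python train.py --conf_dir local/conf.yml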