###
# Author: Kai Li
# Date: 2024-01-22 01:16:22
# Email: lk21@mails.tsinghua.edu.cn
# LastEditTime: 2024-01-24 00:05:10
###
import json
from typing import Any, Dict, List, Optional, Tuple
import os
from omegaconf import OmegaConf
import argparse
import pytorch_lightning as pl
import torch
torch.set_float32_matmul_precision("highest")
import hydra
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
# from pytorch_lightning.loggers import Logger
from omegaconf import DictConfig
import look2hear.system
import look2hear.datas
import look2hear.losses
from look2hear.utils import RankedLogger, instantiate, print_only
import warnings
warnings.filterwarnings("ignore")
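
# The YAML config consumed by train() is expected to contain Hydra-instantiable
# nodes (each with a `_target_`) named: datas, model, discriminator,
# optimizer_g, optimizer_d, scheduler_g, scheduler_d, loss_g, loss_d, metrics,
# system, early_stopping, checkpoint, logger and trainer, plus an `exp` block
# with `dir` and `name` used for output paths. A minimal illustrative sketch
# (the concrete targets below are placeholders, not the project's defaults):
#
#   exp:
#     dir: ./Experiments
#     name: debug_run
#   optimizer_g:
#     _target_: torch.optim.AdamW
#     lr: 1.0e-3
#   scheduler_g:
#     _target_: torch.optim.lr_scheduler.StepLR
#     step_size: 2
#   ...remaining nodes follow the same pattern.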
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # instantiate datamodule
    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)
    # instantiate model and discriminator
    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
    print_only(f"Instantiating Discriminator <{cfg.discriminator._target_}>")
    discriminator: torch.nn.Module = hydra.utils.instantiate(cfg.discriminator)

    # instantiate optimizers (one for the generator, one for the discriminator)
    print_only(f"Instantiating optimizer <{cfg.optimizer_g._target_}>")
    optimizer_g: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_g, params=model.parameters())
    optimizer_d: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_d, params=discriminator.parameters())
    # optimizer: torch.optim = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)

    # instantiate schedulers
    print_only(f"Instantiating scheduler <{cfg.scheduler_g._target_}>")
    scheduler_g: torch.optim.lr_scheduler.LRScheduler = hydra.utils.instantiate(cfg.scheduler_g, optimizer=optimizer_g)
    scheduler_d: torch.optim.lr_scheduler.LRScheduler = hydra.utils.instantiate(cfg.scheduler_d, optimizer=optimizer_d)
    # instantiate losses
    print_only(f"Instantiating loss <{cfg.loss_g._target_}>")
    loss_g: torch.nn.Module = hydra.utils.instantiate(cfg.loss_g)
    loss_d: torch.nn.Module = hydra.utils.instantiate(cfg.loss_d)
    losses = {"g": loss_g, "d": loss_d}

    # instantiate metrics
    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)
    # instantiate system
    print_only(f"Instantiating system <{cfg.system._target_}>")
    system: LightningModule = hydra.utils.instantiate(
        cfg.system,
        model=model,
        discriminator=discriminator,
        loss_func=losses,
        metrics=metrics,
        optimizer=[optimizer_g, optimizer_d],
        scheduler=[scheduler_g, scheduler_d],
    )
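
    # NOTE (assumption): the system LightningModule is expected to accept the
    # generator/discriminator pair, loss_func as a {"g": ..., "d": ...} dict,
    # and matched [generator, discriminator] optimizer/scheduler lists, i.e. it
    # implements the alternating GAN-style optimization itself. The exact
    # signature is defined by the look2hear.system class named in cfg.system._target_.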
    # instantiate callbacks
    callbacks: List[Callback] = []
    if cfg.get("early_stopping"):
        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
    if cfg.get("checkpoint"):
        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
        checkpoint: pl.callbacks.ModelCheckpoint = hydra.utils.instantiate(cfg.checkpoint)
        callbacks.append(checkpoint)

    # instantiate logger
    print_only(f"Instantiating logger <{cfg.logger._target_}>")
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
    logger = hydra.utils.instantiate(cfg.logger)

    # instantiate trainer
    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        strategy=DDPStrategy(find_unused_parameters=True),
    )
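
    # A plausible (illustrative, not authoritative) shape for cfg.trainer, since
    # it is Hydra-instantiated and then combined with the callbacks, logger and
    # DDP strategy passed above:
    #
    #   trainer:
    #     _target_: pytorch_lightning.Trainer
    #     max_epochs: 200      # values here are placeholders
    #     accelerator: gpu
    #     devices: -1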
    trainer.fit(system, datamodule=datamodule)
    print_only("Training finished!")

    # dump the top-k checkpoint scores and re-export the best weights
    # (assumes cfg.checkpoint was provided above, so `checkpoint` exists)
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # reload the best checkpoint and save the serialized audio model
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.audio_model.serialize()
    torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))

    import wandb
    if wandb.run:
        print_only("Closing wandb!")
        wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conf_dir",
        default="local/conf.yml",
        help="Full path to the YAML training config",
    )
    args = parser.parse_args()
    cfg = OmegaConf.load(args.conf_dir)

    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
    # save the config to a new file in the experiment directory
    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))

    train(cfg)
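
# Example invocation (assuming this script is saved as train.py):
#   python train.py --conf_dir local/conf.yml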