|
import argparse
import json
import os
import sys
from collections.abc import MutableMapping
from dataclasses import dataclass

import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import *
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from rich import print, reconfigure
from rich.console import Console
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import look2hear.datas
import look2hear.losses
import look2hear.metrics
import look2hear.models
import look2hear.system
import look2hear.utils
from look2hear.system import make_optimizer
from look2hear.utils import print_only, MyRichProgressBar, RichProgressBarTheme
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
import wandb |
|
wandb.login() |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--conf_dir", |
|
default="local/conf.yml", |
|
help="Full path to save best validation model", |
|
) |
|
|
|
def main(config): |
|
print_only( |
|
"Instantiating datamodule <{}>".format(config["datamodule"]["data_name"]) |
|
) |
|
datamodule: object = getattr(look2hear.datas, config["datamodule"]["data_name"])( |
|
**config["datamodule"]["data_config"] |
|
) |
|
datamodule.setup() |
|
|
|
train_loader, val_loader, test_loader = datamodule.make_loader |
|
|
|
|
|
print_only( |
|
"Instantiating AudioNet <{}>".format(config["audionet"]["audionet_name"]) |
|
) |
|
model = getattr(look2hear.models, config["audionet"]["audionet_name"])( |
|
sample_rate=config["datamodule"]["data_config"]["sample_rate"], |
|
**config["audionet"]["audionet_config"], |
|
) |
|
|
|
print_only("Instantiating Optimizer <{}>".format(config["optimizer"]["optim_name"])) |
|
optimizer = make_optimizer(model.parameters(), **config["optimizer"]) |
|
|
|
|
|
scheduler = None |
|
if config["scheduler"]["sche_name"]: |
|
print_only( |
|
"Instantiating Scheduler <{}>".format(config["scheduler"]["sche_name"]) |
|
) |
|
if config["scheduler"]["sche_name"] != "DPTNetScheduler": |
|
scheduler = getattr(torch.optim.lr_scheduler, config["scheduler"]["sche_name"])( |
|
optimizer=optimizer, **config["scheduler"]["sche_config"] |
|
) |
|
else: |
|
scheduler = { |
|
"scheduler": getattr(look2hear.system.schedulers, config["scheduler"]["sche_name"])( |
|
optimizer, len(train_loader) // config["datamodule"]["data_config"]["batch_size"], 64 |
|
), |
|
"interval": "step", |
|
} |
|
|
|
|
|
config["main_args"]["exp_dir"] = os.path.join( |
|
os.getcwd(), "Experiments", "checkpoint", config["exp"]["exp_name"] |
|
) |
|
exp_dir = config["main_args"]["exp_dir"] |
|
os.makedirs(exp_dir, exist_ok=True) |
|
conf_path = os.path.join(exp_dir, "conf.yml") |
|
with open(conf_path, "w") as outfile: |
|
yaml.safe_dump(config, outfile) |
|
|
|
|
|
print_only( |
|
"Instantiating Loss, Train <{}>, Val <{}>".format( |
|
config["loss"]["train"]["sdr_type"], config["loss"]["val"]["sdr_type"] |
|
) |
|
) |
|
loss_func = { |
|
"train": getattr(look2hear.losses, config["loss"]["train"]["loss_func"])( |
|
getattr(look2hear.losses, config["loss"]["train"]["sdr_type"]), |
|
**config["loss"]["train"]["config"], |
|
), |
|
"val": getattr(look2hear.losses, config["loss"]["val"]["loss_func"])( |
|
getattr(look2hear.losses, config["loss"]["val"]["sdr_type"]), |
|
**config["loss"]["val"]["config"], |
|
), |
|
} |
|
|
|
print_only("Instantiating System <{}>".format(config["training"]["system"])) |
|
system = getattr(look2hear.system, config["training"]["system"])( |
|
audio_model=model, |
|
loss_func=loss_func, |
|
optimizer=optimizer, |
|
train_loader=train_loader, |
|
val_loader=val_loader, |
|
test_loader=test_loader, |
|
scheduler=scheduler, |
|
config=config, |
|
) |
|
|
|
|
|
print_only("Instantiating ModelCheckpoint") |
|
callbacks = [] |
|
checkpoint_dir = os.path.join(exp_dir) |
|
checkpoint = ModelCheckpoint( |
|
checkpoint_dir, |
|
filename="{epoch}", |
|
monitor="val_loss/dataloader_idx_0", |
|
mode="min", |
|
save_top_k=5, |
|
verbose=True, |
|
save_last=True, |
|
) |
|
callbacks.append(checkpoint) |
|
|
|
if config["training"]["early_stop"]: |
|
print_only("Instantiating EarlyStopping") |
|
callbacks.append(EarlyStopping(**config["training"]["early_stop"])) |
|
callbacks.append(MyRichProgressBar(theme=RichProgressBarTheme())) |
|
|
|
|
|
gpus = config["training"]["gpus"] if torch.cuda.is_available() else None |
|
distributed_backend = "cuda" if torch.cuda.is_available() else None |
|
|
|
|
|
logger_dir = os.path.join(os.getcwd(), "Experiments", "tensorboard_logs") |
|
os.makedirs(os.path.join(logger_dir, config["exp"]["exp_name"]), exist_ok=True) |
|
|
|
comet_logger = WandbLogger( |
|
name=config["exp"]["exp_name"], |
|
save_dir=os.path.join(logger_dir, config["exp"]["exp_name"]), |
|
project="Real-work-dataset", |
|
|
|
) |
|
|
|
trainer = pl.Trainer( |
|
max_epochs=config["training"]["epochs"], |
|
callbacks=callbacks, |
|
default_root_dir=exp_dir, |
|
devices=gpus, |
|
accelerator=distributed_backend, |
|
strategy=DDPStrategy(find_unused_parameters=True), |
|
limit_train_batches=1.0, |
|
gradient_clip_val=5.0, |
|
logger=comet_logger, |
|
sync_batchnorm=True, |
|
|
|
|
|
|
|
|
|
) |
|
trainer.fit(system) |
|
print_only("Finished Training") |
|
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} |
|
with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: |
|
json.dump(best_k, f, indent=0) |
|
|
|
state_dict = torch.load(checkpoint.best_model_path) |
|
system.load_state_dict(state_dict=state_dict["state_dict"]) |
|
system.cpu() |
|
|
|
to_save = system.audio_model.serialize() |
|
torch.save(to_save, os.path.join(exp_dir, "best_model.pth")) |
|
|
|
|
|
if __name__ == "__main__": |
|
import yaml |
|
from pprint import pprint |
|
from look2hear.utils.parser_utils import ( |
|
prepare_parser_from_dict, |
|
parse_args_as_dict, |
|
) |
|
|
|
args = parser.parse_args() |
|
with open(args.conf_dir) as f: |
|
def_conf = yaml.safe_load(f) |
|
parser = prepare_parser_from_dict(def_conf, parser=parser) |
|
|
|
arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True) |
|
|
|
main(arg_dic) |
|
|