"""
Module for training the 2D clutter filtering model with L2 loss.
"""
import os
import argparse
import json

import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

# NOTE: the wildcard import supplies `id_preparation` and `create_weight_dir`
# used in `main`; it is kept as-is because other names may rely on it.
from utils import *
from Model_ClutterFilter2D import clutter_filter_2D
from DataGen import DataGen


def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config: dict):
    """Build the training and validation data generators.

    Args:
        in_ids_tr: Training inputs, forwarded to ``DataGen`` as ``in_dir``.
        in_ids_val: Validation inputs, forwarded to ``DataGen`` as ``in_dir``.
        out_ids_tr: Training targets, forwarded to ``DataGen`` as ``out_dir``.
        out_ids_val: Validation targets, forwarded to ``DataGen`` as ``out_dir``.
        config: Configuration dictionary; reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``(tr_gen, val_gen)`` tuple of ``DataGen`` instances.
    """
    input_dim = config["network_prm"]["input_dim"]
    batch_size = config["learning_prm"]["batch_size"]

    DtaGenTr_prm = {
        'dim': input_dim,
        'in_dir': in_ids_tr,
        'out_dir': out_ids_tr,
        'id_list': np.arange(len(in_ids_tr)),
        'batch_size': batch_size,
        'tr_phase': True,
    }
    DtaGenVal_prm = {
        'dim': input_dim,
        'in_dir': in_ids_val,
        'out_dir': out_ids_val,
        'id_list': np.arange(len(in_ids_val)),
        'batch_size': batch_size,
        # NOTE(review): validation also runs with tr_phase=True; presumably
        # so the generator still yields targets -- confirm against DataGen.
        'tr_phase': True,
    }
    tr_gen = DataGen(**DtaGenTr_prm)
    val_gen = DataGen(**DtaGenVal_prm)
    return tr_gen, val_gen


def model_chkpnt(val_subject, te_subject, weight_dir, config: dict):
    """Create the checkpoint and learning-rate-reduction callbacks.

    The checkpoint file name encodes the validation/test subjects, the
    network settings (``n_levels``, ``in_skip``, ``attention``, ``act``,
    ``n_init_filters``), the learning rate, and per-epoch placeholders for
    the epoch number, training loss and validation loss.

    Args:
        val_subject: Validation subject identifier, embedded in the file name.
        te_subject: Test subject identifier, embedded in the file name.
        weight_dir: Directory in which the checkpoint files are written.
        config: Configuration dictionary; reads the ``network_prm`` keys
            listed above and ``config["learning_prm"]["lr"]``.

    Returns:
        A ``(model_checkpoint, reduce_lr)`` tuple of Keras callbacks.
    """
    net_prm = config["network_prm"]
    weight_name = (
        f'CF2D_ValTeSbj_{val_subject}_{te_subject}_nLvl{net_prm["n_levels"]}'
        f'_InSkp{net_prm["in_skip"]}_Att{net_prm["attention"]}'
        f'_Act{net_prm["act"]}_nInitFlt{net_prm["n_init_filters"]}'
        f'_lr{config["learning_prm"]["lr"]}')
    # Deliberately NOT an f-string: the {epoch}/{loss}/{val_loss} fields are
    # filled in by ModelCheckpoint each time it saves.
    file_name = (weight_name
                 + "_epc{epoch:03d}_trloss{loss:.5f}_valloss{val_loss:.5f}.hdf5")
    filepath = os.path.join(weight_dir, file_name)

    # Only keep checkpoints that improve the validation loss.
    model_checkpoint = ModelCheckpoint(filepath=filepath,
                                       monitor="val_loss",
                                       verbose=0,
                                       save_best_only=True)
    # Scale the learning rate by 0.1 after 4 epochs without val_loss
    # improvement, never going below 1e-7.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4, min_lr=1e-7)
    return model_checkpoint, reduce_lr


def main(config: dict) -> None:
    """Train the 2D clutter filtering model described by ``config``.

    Prepares the data ids, creates the weight directory, builds the data
    generators, the model and the callbacks, and then fits the model.

    Args:
        config: Configuration dictionary. It is passed to ``id_preparation``,
            ``create_weight_dir``, ``data_generation`` and ``model_chkpnt``,
            and unpacked as keyword arguments into ``clutter_filter_2D``;
            ``config["learning_prm"]["n_epochs"]`` sets the number of epochs.

    Returns:
        None.
    """
    (in_ids_tr, in_ids_val, out_ids_tr, out_ids_val,
     val_subject, te_subject) = id_preparation(config)
    weight_dir = create_weight_dir(val_subject, te_subject, config)
    tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
    model = clutter_filter_2D(**config)
    model_checkpoint, reduce_lr = model_chkpnt(val_subject, te_subject, weight_dir, config)
    model.fit(tr_gen,
              validation_data=val_gen,
              epochs=config["learning_prm"]["n_epochs"],
              verbose=1,
              callbacks=[model_checkpoint, reduce_lr])
    return None


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would let a missing file slip through to `open`.
    if not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config}")
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)