""" Module for testing the 3D clutter filtering model. """ import os import argparse import json import numpy as np import pandas as pd from utils import * from Model_ClutterFilter3D import clutter_filter_3D from DataGen import DataGen from Error_analysis import compute_mae def data_generation(in_ids_te, out_ids_te, config): DtaGenTe_prm = { 'dim': config["network_prm"]["input_dim"], 'in_dir': in_ids_te, 'out_dir': out_ids_te, 'id_list': np.arange(len(in_ids_te)), 'batch_size': config["learning_prm"]["batch_size"], 'tr_phase': False} return DataGen(**DtaGenTe_prm) def main(config): in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config) te_gen = data_generation(in_ids_te, out_ids_te, config) model = clutter_filter_3D(**config) weight_dir = create_weight_dir(val_subject, te_subject, config) model.load_weights( os.path.join(weight_dir, config["weight_name"] + ".hdf5")) results_te = model.predict_generator(te_gen, verbose=2) df_errors = compute_mae(in_ids_te, results_te) df_errors.to_csv( os.path.join(weight_dir, config["weight_name"] + ".csv")) return None if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--config", help="path of the config file", default="config.json") args = parser.parse_args() assert os.path.isfile(args.config) with open(args.config, "r") as read_file: config = json.load(read_file) main(config)