File size: 1,532 Bytes
6477265 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
"""
Module for testing the 3D clutter filtering model.
"""
import os
import argparse
import json
import numpy as np
import pandas as pd
from utils import *
from Model_ClutterFilter3D import clutter_filter_3D
from DataGen import DataGen
from Error_analysis import compute_mae
def data_generation(in_ids_te, out_ids_te, config):
    """Build the data generator used for the test phase.

    Args:
        in_ids_te: Sized collection describing the test inputs; forwarded to
            ``DataGen`` as ``in_dir`` and used to size the index list.
        out_ids_te: Collection describing the test targets; forwarded to
            ``DataGen`` as ``out_dir``.
        config: Configuration mapping. Reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``DataGen`` instance created with ``tr_phase=False``.
    """
    input_dim = config["network_prm"]["input_dim"]
    # One integer index per test input: 0 .. len(in_ids_te) - 1.
    sample_indices = np.arange(len(in_ids_te))
    batch_size = config["learning_prm"]["batch_size"]
    return DataGen(
        dim=input_dim,
        in_dir=in_ids_te,
        out_dir=out_ids_te,
        id_list=sample_indices,
        batch_size=batch_size,
        tr_phase=False,
    )
def main(config):
    """Run the test phase and write the resulting errors to a CSV file.

    The function prepares the test identifiers, builds the test data
    generator, instantiates the model, loads previously saved weights, runs
    prediction over the generator and stores the output of ``compute_mae``
    next to the weight file.

    Args:
        config: Parsed configuration mapping. It is passed to
            ``id_preparation``, ``data_generation`` and ``create_weight_dir``,
            and unpacked as keyword arguments into ``clutter_filter_3D``.
            ``config["weight_name"]`` is the base name of both the weight
            file (``<weight_name>.hdf5``) and the output
            (``<weight_name>.csv``) inside the weight directory.

    Returns:
        None.
    """
    in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
    te_gen = data_generation(in_ids_te, out_ids_te, config)

    model = clutter_filter_3D(**config)
    weight_dir = create_weight_dir(val_subject, te_subject, config)
    model.load_weights(
        os.path.join(weight_dir, config["weight_name"] + ".hdf5"))

    # NOTE(review): assumes ``model`` is a Keras-style model -- confirm.
    # ``predict_generator`` is deprecated in TensorFlow/Keras and absent from
    # newer releases, where ``predict`` accepts generators directly. Prefer
    # the legacy method when it exists so older environments behave exactly
    # as before, and fall back to ``predict`` otherwise.
    predict = getattr(model, "predict_generator", None)
    if predict is None:
        predict = model.predict
    results_te = predict(te_gen, verbose=2)

    df_errors = compute_mae(in_ids_te, results_te)
    df_errors.to_csv(
        os.path.join(weight_dir, config["weight_name"] + ".csv"))
if __name__ == '__main__':
    # Command-line entry point: read the JSON config given by --config
    # (default "config.json") and hand the parsed mapping to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly rather than with ``assert``: assertions are removed
    # when Python runs with -O, which would silently skip this check.
    # parser.error() prints the usage line plus the message and exits with
    # status 2.
    if not os.path.isfile(args.config):
        parser.error("config file not found: {}".format(args.config))
    with open(args.config, "r") as read_file:
        config = json.load(read_file)
    main(config)