File size: 1,532 Bytes
6477265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
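"""
Evaluate a trained 3D clutter-filtering model on the test set.

Loads the saved model weights, runs prediction over the test data, and
writes the error table produced by ``compute_mae`` to a CSV file in the
weight directory.
"""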
import os
import argparse
import json
import numpy as np
import pandas as pd

from utils import *
from Model_ClutterFilter3D import clutter_filter_3D
from DataGen import DataGen
from Error_analysis import compute_mae
    
def data_generation(in_ids_te, out_ids_te, config):
    """Build the test-time ``DataGen`` for the given ID lists.

    Args:
        in_ids_te: Input sample identifiers for the test split; forwarded to
            ``DataGen`` as ``in_dir``.
        out_ids_te: Target sample identifiers; forwarded as ``out_dir``.
        config: Parsed configuration dict. Reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``DataGen`` built with ``tr_phase=False`` and one integer index
        (0 .. len(in_ids_te) - 1) per input ID.
    """
    input_dim = config["network_prm"]["input_dim"]
    # Indices address positions in in_ids_te / out_ids_te.
    sample_indices = np.arange(len(in_ids_te))
    batch_size = config["learning_prm"]["batch_size"]
    return DataGen(
        dim=input_dim,
        in_dir=in_ids_te,
        out_dir=out_ids_te,
        id_list=sample_indices,
        batch_size=batch_size,
        tr_phase=False,
    )

def main(config):
    """Evaluate a trained 3D clutter-filter model on the test split.

    Pipeline: resolve test IDs, build the test generator, construct the
    model, load ``<weight_dir>/<weight_name>.hdf5``, predict, compute errors
    with ``compute_mae`` and save them to ``<weight_dir>/<weight_name>.csv``.

    Args:
        config: Parsed configuration dict. ``config["weight_name"]`` is read
            here; the remaining keys are consumed by the helpers it is
            passed to.

    Returns:
        None.
    """
    in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
    test_generator = data_generation(in_ids_te, out_ids_te, config)
    model = clutter_filter_3D(**config)

    # Weight directory is derived from the validation/test subject split.
    weight_dir = create_weight_dir(val_subject, te_subject, config)
    weight_name = config["weight_name"]
    model.load_weights(os.path.join(weight_dir, weight_name + ".hdf5"))

    predictions = model.predict_generator(test_generator, verbose=2)
    # NOTE(review): compute_mae receives the *input* IDs, not out_ids_te;
    # presumably it derives the targets from them — confirm.
    error_table = compute_mae(in_ids_te, predictions)
    error_table.to_csv(os.path.join(weight_dir, weight_name + ".csv"))

if __name__ == '__main__':
    # Command-line entry point: load a JSON config file and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly instead of using ``assert``: asserts are stripped
    # under ``python -O`` and give an unhelpful bare AssertionError.
    if not os.path.isfile(args.config):
        parser.error("config file not found: {}".format(args.config))
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)