""" Utility functions. """ import os import numpy as np def generate_tr_val_te_subject_ids(subject_list, val_subject_id): val_subject = subject_list[val_subject_id] te_subject = subject_list[val_subject_id-1] subject_list.remove(val_subject) subject_list.remove(te_subject) tr_subjects = subject_list return tr_subjects, val_subject, te_subject def generate_data_ids(data_dir, subject_list): in_ids, out_ids = [], [] vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor] for vendor in vendor_list: vendor_dir = os.path.join(data_dir, vendor) view_list = [view for view in os.listdir(vendor_dir) if '.' not in view] for view in view_list: view_dir = os.path.join(vendor_dir, view) subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject] for subject in subject_full_list: if subject in subject_list: subject_dir = os.path.join(view_dir, subject) org_data_dir = os.path.join(subject_dir, 'data_org') org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0]) clutter_list = [clutter for clutter in os.listdir(subject_dir) if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt'] and '.' not in clutter] for clutter in clutter_list: clutter_dir = os.path.join(subject_dir, clutter) clutter_ids = os.listdir(clutter_dir) clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir] in_ids += clutter_ids_dir out_ids += [org_data_id]*len(os.listdir(clutter_dir)) return in_ids, out_ids def id_preparation(config): tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids( subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"]) if config["tr_phase"]: in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects) in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject) return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject else: in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject) return in_ids_te, out_ids_te, te_subject, val_subject def create_weight_dir(val_subject, te_subject, config): weight_dir = os.path.join(config["paths"]["save_path"], "Weights", f"ValTeIDs_{val_subject}_{te_subject}") if not os.path.exists(weight_dir): os.makedirs(weight_dir) return weight_dir