MahdiTabassian's picture
Filtering models and example video clips
6477265
"""
Utility functions.
"""
import os
import numpy as np
def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
    """Split the subjects into training, validation and test subjects.

    The validation subject is the one at index ``val_subject_id``; the test
    subject is the one immediately before it. Because Python indexing wraps
    around, ``val_subject_id == 0`` selects the last subject as the test
    subject.

    Args:
        subject_list: Sequence of subject identifiers. It is NOT modified.
        val_subject_id: Index of the validation subject in ``subject_list``.

    Returns:
        A tuple ``(tr_subjects, val_subject, te_subject)`` where
        ``tr_subjects`` is a new list holding the remaining subjects.

    Raises:
        IndexError: If ``val_subject_id`` is out of range.
        ValueError: If the validation and test subjects are the same single
            entry (e.g. a one-element list), since it cannot be removed twice.
    """
    val_subject = subject_list[val_subject_id]
    te_subject = subject_list[val_subject_id - 1]
    # Work on a copy: removing from the caller's list would silently shrink
    # it (e.g. the list stored in the config) on every call.
    tr_subjects = list(subject_list)
    tr_subjects.remove(val_subject)
    tr_subjects.remove(te_subject)
    return tr_subjects, val_subject, te_subject
def _list_plain_entries(directory):
    """Return the entries of ``directory`` whose names contain no '.'."""
    return [name for name in os.listdir(directory) if '.' not in name]


def generate_data_ids(data_dir, subject_list):
    """Collect paired input/target file paths for the requested subjects.

    Walks ``data_dir/<vendor>/<view>/<subject>``. For every subject found in
    ``subject_list``, each file inside the ``data_NFClt``, ``data_NFRvbClt``
    and ``data_RvbClt`` folders becomes an input ID, paired with the first
    file listed in that subject's ``data_org`` folder as its target ID.

    Args:
        data_dir: Root directory of the dataset.
        subject_list: Subject folder names to include. A single name given
            as a string is treated as a one-element list.

    Returns:
        A tuple ``(in_ids, out_ids)`` of equally long lists of file paths;
        ``out_ids[i]`` is the target path for ``in_ids[i]``.

    Raises:
        IndexError: If a selected subject's ``data_org`` folder holds no
            usable file.
        OSError: If an expected directory is missing or unreadable.
    """
    # A bare string would make ``subject in subject_list`` a substring test
    # (e.g. 'S1' would match 'S10'); normalise it to a list first.
    if isinstance(subject_list, str):
        subject_list = [subject_list]
    clutter_folders = ('data_NFClt', 'data_NFRvbClt', 'data_RvbClt')

    in_ids, out_ids = [], []
    for vendor in _list_plain_entries(data_dir):
        vendor_dir = os.path.join(data_dir, vendor)
        for view in _list_plain_entries(vendor_dir):
            view_dir = os.path.join(vendor_dir, view)
            for subject in _list_plain_entries(view_dir):
                if subject not in subject_list:
                    continue
                subject_dir = os.path.join(view_dir, subject)
                org_data_dir = os.path.join(subject_dir, 'data_org')
                # Skip '.DS' entries so such a file is never used as the target.
                org_files = [name for name in os.listdir(org_data_dir)
                             if '.DS' not in name]
                org_data_id = os.path.join(org_data_dir, org_files[0])
                for clutter in os.listdir(subject_dir):
                    if clutter not in clutter_folders:
                        continue
                    clutter_dir = os.path.join(subject_dir, clutter)
                    clutter_paths = [os.path.join(clutter_dir, name)
                                     for name in os.listdir(clutter_dir)
                                     if '.DS' not in name]
                    in_ids += clutter_paths
                    # Repeat the target once per *kept* input so both lists
                    # stay aligned even when '.DS' entries were skipped.
                    out_ids += [org_data_id] * len(clutter_paths)
    return in_ids, out_ids
def id_preparation(config):
    """Build the input/target ID lists for the current phase.

    Reads ``config["subject_list"]``, ``config["CV"]["val_subject_id"]``,
    ``config["paths"]["data_path"]`` and ``config["tr_phase"]``.

    Args:
        config: Configuration mapping holding the keys listed above.

    Returns:
        When ``config["tr_phase"]`` is truthy:
            ``(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val,
            val_subject, te_subject)``.
        Otherwise:
            ``(in_ids_te, out_ids_te, te_subject, val_subject)``.
        Note that the subject order differs between the two cases.
    """
    # Pass a copy so the subject list stored in the config is never altered
    # by the split, which keeps repeated calls consistent.
    tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
        subject_list=list(config["subject_list"]),
        val_subject_id=config["CV"]["val_subject_id"])
    data_path = config["paths"]["data_path"]

    if config["tr_phase"]:
        # A single string ID is wrapped in a list so it is matched as a whole
        # name rather than by substring.
        val_ids = [val_subject] if isinstance(val_subject, str) else val_subject
        in_ids_tr, out_ids_tr = generate_data_ids(data_path, tr_subjects)
        in_ids_val, out_ids_val = generate_data_ids(data_path, val_ids)
        return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject

    te_ids = [te_subject] if isinstance(te_subject, str) else te_subject
    in_ids_te, out_ids_te = generate_data_ids(data_path, te_ids)
    return in_ids_te, out_ids_te, te_subject, val_subject
def create_weight_dir(val_subject, te_subject, config):
    """Create (if needed) and return the weight directory for a CV fold.

    The directory is
    ``<config["paths"]["save_path"]>/Weights/ValTeIDs_<val_subject>_<te_subject>``.

    Args:
        val_subject: Validation subject identifier used in the folder name.
        te_subject: Test subject identifier used in the folder name.
        config: Configuration mapping providing ``config["paths"]["save_path"]``.

    Returns:
        The path of the weight directory.

    Raises:
        FileExistsError: If the path exists but is not a directory.
    """
    weight_dir = os.path.join(config["paths"]["save_path"],
                              "Weights", f"ValTeIDs_{val_subject}_{te_subject}")
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(weight_dir, exist_ok=True)
    return weight_dir