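"""BatDetect - Prepare Data for Finetuning.

Splits a directory of annotated audio files into TRAIN and TEST json files
for finetuning. Example invocation (the script name and paths here are
illustrative, not taken from the repo):

    python prep_data_finetune.py my_dataset /data/audio/ /data/anns/ /data/splits/ \
        --percent_val 0.2 --rand_seed 2001
"""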

import argparse
import json
import os
import sys

import numpy as np

# make the bat_detect package importable when running from this sub-directory
sys.path.append(os.path.join('..', '..'))
import bat_detect.train.train_utils as tu


def print_dataset_stats(data, split_name, classes_to_ignore):
    # prints the number of files and per-class annotation counts for a split
    print('\nSplit:', split_name)
    print('Num files:', len(data))

    class_cnts = {}
    for dd in data:
        for aa in dd['annotation']:
            if aa['class'] not in classes_to_ignore:
                if aa['class'] in class_cnts:
                    class_cnts[aa['class']] += 1
                else:
                    class_cnts[aa['class']] = 1

    if len(class_cnts) == 0:
        class_names = []
    else:
        class_names = np.sort([*class_cnts]).tolist()
        print('Class count:')
        str_len = np.max([len(cc) for cc in class_names]) + 5
        for ii, cc in enumerate(class_names):
            print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))

    return class_names
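
# Illustrative output of print_dataset_stats (class names and counts made up):
#
#   Split: Train
#   Num files: 80
#   Class count:
#   0    Myotis daubentonii            152
#   1    Pipistrellus pipistrellus     301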


def load_file_names(file_name):
    # reads a list of wav file names (one per line) from a text file
    if os.path.isfile(file_name):
        with open(file_name) as da:
            files = [line.rstrip() for line in da.readlines()]
        for ff in files:
            if not ff.lower().endswith('.wav'):
                print('Error: Filenames need to end in .wav - ', ff)
                assert False
    else:
        print('Error: Input file not found - ', file_name)
        assert False

    return files
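
# Illustrative contents of a file passed via --train_file / --test_file
# (one wav filename per line; these names are made up):
#
#   site_a_20200101_223000.wav
#   site_a_20200102_014500.wav
#   site_b_20200103_031500.wav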


if __name__ == "__main__":

    info_str = '\nBatDetect - Prepare Data for Finetuning\n'
    print(info_str)

    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_name', type=str,
                        help='Name to call your dataset')
    parser.add_argument('audio_dir', type=str,
                        help='Input directory for audio')
    parser.add_argument('ann_dir', type=str,
                        help='Input directory for where the audio annotations are stored')
    parser.add_argument('op_dir', type=str,
                        help='Path where the train and test splits will be stored')
    parser.add_argument('--percent_val', type=float, default=0.20,
                        help='Hold out this much data for validation. Should be a number between 0 and 1')
    parser.add_argument('--rand_seed', type=int, default=2001,
                        help='Random seed used for creating the validation split')
    parser.add_argument('--train_file', type=str, default='',
                        help='Text file where each line is a wav file in the train split')
    parser.add_argument('--test_file', type=str, default='',
                        help='Text file where each line is a wav file in the test split')
    parser.add_argument('--input_class_names', type=str, default='',
                        help='Specify names of classes that you want to change. Separate with ";"')
    parser.add_argument('--output_class_names', type=str, default='',
                        help='New class names to use instead. One to one mapping with '
                             '"--input_class_names". Separate with ";"')
    args = vars(parser.parse_args())

    np.random.seed(args['rand_seed'])

    classes_to_ignore = ['', ' ', 'Unknown', 'Not Bat']
    generic_class = ['Bat']
    events_of_interest = ['Echolocation']

    if args['input_class_names'] != '' and args['output_class_names'] != '':
        # change the names of the classes
        ip_names = args['input_class_names'].split(';')
        op_names = args['output_class_names'].split(';')
        name_dict = dict(zip(ip_names, op_names))
    else:
        name_dict = False
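    # e.g. --input_class_names "Myotis daubentoni" --output_class_names
    # "Myotis daubentonii" would correct a misspelled class name in the
    # annotations (the species names here are illustrative)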

    # load annotations
    data_all, _, _ = tu.load_set_of_anns({'ann_path': args['ann_dir'], 'wav_path': args['audio_dir']},
                                         classes_to_ignore, events_of_interest, False, False,
                                         list_of_anns=True, filter_issues=True, name_replace=name_dict)

    print('Dataset name: ' + args['dataset_name'])
    print('Audio directory: ' + args['audio_dir'])
    print('Annotation directory: ' + args['ann_dir'])
    print('Output directory: ' + args['op_dir'])
    print('Num annotated files: ' + str(len(data_all)))

    if args['train_file'] != '' and args['test_file'] != '':
        # user has specified the train / test split
        # note: listed files that do not appear in the annotations are skipped
        train_files = load_file_names(args['train_file'])
        test_files = load_file_names(args['test_file'])
        file_names_all = [dd['id'] for dd in data_all]
        train_inds = [file_names_all.index(ff) for ff in train_files if ff in file_names_all]
        test_inds = [file_names_all.index(ff) for ff in test_files if ff in file_names_all]

    else:
        # split the data into train and test at the file level
        num_exs = len(data_all)
        test_inds = np.random.choice(np.arange(num_exs), int(num_exs*args['percent_val']), replace=False)
        test_inds = np.sort(test_inds)
        train_inds = np.setdiff1d(np.arange(num_exs), test_inds)

    data_train = [data_all[ii] for ii in train_inds]
    data_test = [data_all[ii] for ii in test_inds]

    if not os.path.isdir(args['op_dir']):
        os.makedirs(args['op_dir'])
    op_name = os.path.join(args['op_dir'], args['dataset_name'])
    op_name_train = op_name + '_TRAIN.json'
    op_name_test = op_name + '_TEST.json'

    class_un_train = print_dataset_stats(data_train, 'Train', classes_to_ignore)
    class_un_test = print_dataset_stats(data_test, 'Test', classes_to_ignore)

    if len(data_train) > 0 and len(data_test) > 0:
        if class_un_train != class_un_test:
            print('\nError: some classes are not in both the training and test sets.'
                  '\nTry a different random seed "--rand_seed".')
            assert False

    print('\n')

    if len(data_train) == 0:
        print('No train annotations to save')
    else:
        print('Saving: ', op_name_train)
        with open(op_name_train, 'w') as da:
            json.dump(data_train, da, indent=2)

    if len(data_test) == 0:
        print('No test annotations to save')
    else:
        print('Saving: ', op_name_test)
        with open(op_name_test, 'w') as da:
            json.dump(data_test, da, indent=2)
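
# On success the script writes two files:
#   <op_dir>/<dataset_name>_TRAIN.json and <op_dir>/<dataset_name>_TEST.json,
# each containing a JSON list of per-file annotation dictionaries.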