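"""BatDetect - Finetune Model.

Finetunes a pretrained BatDetect model on a new annotated dataset,
optionally training only the final layers. A sketch of the expected
invocation, with placeholder paths (the positional arguments are
defined by the argparse setup below):

    python finetune_model.py /path/to/audio/ dataset_TRAIN.json dataset_TEST.json pretrained_model.pth
"""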
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import glob
import sys
sys.path.append(os.path.join('..', '..'))
import bat_detect.train.train_model as tm
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.losses as losses
import bat_detect.detector.parameters as parameters
import bat_detect.detector.models as models
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.utils.detector_utils as du
if __name__ == "__main__":
    info_str = '\nBatDetect - Finetune Model\n'
    print(info_str)

    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help='Input directory for audio')
    parser.add_argument('train_ann_path', type=str,
                        help='Path to where train annotation file is stored')
    parser.add_argument('test_ann_path', type=str,
                        help='Path to where test annotation file is stored')
    parser.add_argument('model_path', type=str,
                        help='Path to pretrained model')
    parser.add_argument('--op_model_name', type=str, default='',
                        help='Path and name for finetuned model')
    parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
                        help='Number of finetuning epochs')
    parser.add_argument('--finetune_only_last_layer', action='store_true',
                        help='Only train final layers')
    parser.add_argument('--train_from_scratch', action='store_true',
                        help='Do not use pretrained weights')
    # note: 'store_true' so that images are saved by default and passing
    # --do_not_save_images disables saving (checked at the end of this script)
    parser.add_argument('--do_not_save_images', action='store_true',
                        help='Do not save images at the end of training')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes to save in text file')
    args = vars(parser.parse_args())

    params = parameters.get_params(True, '../../experiments/')
    if torch.cuda.is_available():
        params['device'] = 'cuda'
    else:
        params['device'] = 'cpu'
        print('\nNote: this will be a lot faster if you use a computer with a GPU.\n')
    print('\nAudio directory: ' + args['audio_path'])
    print('Train file: ' + args['train_ann_path'])
    print('Test file: ' + args['test_ann_path'])
    print('Loading model: ' + args['model_path'])

    dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')

    if args['train_from_scratch']:
        print('\nTraining model from scratch, i.e. not using pretrained weights')
        model, params_train = du.load_model(args['model_path'], False)
    else:
        model, params_train = du.load_model(args['model_path'], True)
    model.to(params['device'])
    params['num_epochs'] = args['num_epochs']
    if args['op_model_name'] != '':
        params['model_file_name'] = args['op_model_name']
    classes_to_ignore = params['classes_to_ignore'] + params['generic_class']

    # save notes file
    params['notes'] = args['notes']
    if args['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])
    # load train annotations - store only the annotation file name (not the
    # full path) in the params that get saved alongside the model
    train_sets = []
    train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path']))
    params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False,
                            os.path.basename(args['train_ann_path']), args['audio_path'])]

    print('\nTrain set:')
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_train))

    params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])

    # load test annotations
    test_sets = []
    test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path']))
    params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True,
                           os.path.basename(args['test_ann_path']), args['audio_path'])]

    print('\nTest set:')
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_test))
    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                                               shuffle=True, num_workers=params['num_workers'],
                                               pin_memory=True)

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=params['num_workers'],
                                              pin_memory=True)

    # read one batch to get the input spectrogram height
    inputs_train = next(iter(train_loader))
    params['ip_height'] = inputs_train['spec'].shape[2]
    print('\ntrain batch shape:', inputs_train['spec'].shape)
    assert params_train['model_name'] == 'Net2DFast'
    print('\n\nNote: some hyperparameters (e.g. FFT settings) need to match the loaded model - currently they are being overwritten.\n\n')

    # replace the classification head so the number of output classes matches
    # the new dataset (+1 for the generic class)
    num_filts = model.conv_classes_op.in_channels
    k_size = model.conv_classes_op.kernel_size
    pad = model.conv_classes_op.padding
    model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1,
                                            kernel_size=k_size, padding=pad)
    model.conv_classes_op.to(params['device'])
    if args['finetune_only_last_layer']:
        print('\nOnly finetuning the final layers.\n')
        train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op']
        train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i]
        for name, param in model.named_parameters():
            if name in train_layers:
                param.requires_grad = True
            else:
                param.requires_grad = False
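    # Note: the optimizer below is still handed all of model.parameters();
    # parameters frozen above (requires_grad=False) never accumulate gradients,
    # so the optimizer leaves them unchanged.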
    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
    if params['train_loss'] == 'mse':
        det_criterion = losses.mse_loss
    elif params['train_loss'] == 'focal':
        det_criterion = losses.focal_loss
    else:
        raise ValueError('Unknown train_loss: ' + params['train_loss'])
    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
                                  ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
                                 ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
                              ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'],
                              [0, 1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
                                    params['class_names_short'], [0, 1], params['class_names_short'],
                                    ['epoch', 'avg_prec'])
    # main train loop
    for epoch in range(0, params['num_epochs']+1):

        train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
        train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])

        if epoch % params['num_eval_epochs'] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
            test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
            test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
                                             test_res['avg_prec_class'], test_res['file_acc'],
                                             test_res['top_class']['avg_prec']])
            test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
            pu.plot_pr_curve_class(params['experiment'], 'test_pr', 'test_pr', test_res)

            # save finetuned model
            print('saving model to: ' + params['model_file_name'])
            op_state = {'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'params': params}
            torch.save(op_state, params['model_file_name'])
    # save an image with associated prediction for each batch in the test set
    if not args['do_not_save_images']:
        tm.save_images_batch(model, test_loader, params)
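
    # Outputs: loss and metric plots are written under params['experiment'],
    # and the finetuned checkpoint is saved to params['model_file_name'].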