File size: 8,913 Bytes
9ace58a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import glob

import sys
sys.path.append(os.path.join('..', '..'))
import bat_detect.train.train_model as tm
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.losses as losses

import bat_detect.detector.parameters as parameters
import bat_detect.detector.models as models
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.utils.detector_utils as du


if __name__ == "__main__":

    # ------------------------------------------------------------------
    # BatDetect - model finetuning script.
    # Loads a pretrained Net2DFast detector, swaps out its classification
    # head to match the class list of the user-supplied annotations, and
    # finetunes it on the provided train/test splits.
    # ------------------------------------------------------------------

    info_str = '\nBatDetect - Finetune Model\n'

    print(info_str)
    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help='Input directory for audio')
    parser.add_argument('train_ann_path', type=str,
                        help='Path to where train annotation file is stored')
    parser.add_argument('test_ann_path', type=str,
                        help='Path to where test annotation file is stored')
    parser.add_argument('model_path', type=str,
                        help='Path to pretrained model')
    parser.add_argument('--op_model_name', type=str, default='',
                        help='Path and name for finetuned model')
    parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
                         help='Number of finetuning epochs')
    parser.add_argument('--finetune_only_last_layer', action='store_true',
                         help='Only train final layers')
    parser.add_argument('--train_from_scratch', action='store_true',
                         help='Do not use pretrained weights')
    # BUGFIX: this was action='store_false' (default True), which inverted
    # the flag - images were skipped by default and saved only when the
    # user explicitly asked NOT to save them. 'store_true' (default False)
    # matches the documented intent: save by default, suppress with flag.
    parser.add_argument('--do_not_save_images', action='store_true',
                         help='Do not save images at the end of training')
    parser.add_argument('--notes', type=str, default='',
                         help='Notes to save in text file')
    args = vars(parser.parse_args())

    params = parameters.get_params(True, '../../experiments/')
    if torch.cuda.is_available():
        params['device'] = 'cuda'
    else:
        params['device'] = 'cpu'
        print('\nNote, this will be a lot faster if you use computer with a GPU.\n')

    print('\nAudio directory:      ' + args['audio_path'])
    print('Train file:           ' + args['train_ann_path'])
    print('Test file:            ' + args['test_ann_path'])
    print('Loading model:        ' + args['model_path'])

    # Dataset name is derived from the train annotation filename, e.g.
    # "mydata_TRAIN.json" -> "mydata".
    dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')

    # Second argument to load_model controls whether pretrained weights
    # are loaded into the network.
    if args['train_from_scratch']:
        print('\nTraining model from scratch i.e. not using pretrained weights')
        model, params_train = du.load_model(args['model_path'], False)
    else:
        model, params_train = du.load_model(args['model_path'], True)
    model.to(params['device'])

    params['num_epochs'] = args['num_epochs']
    if args['op_model_name'] != '':
        params['model_file_name'] = args['op_model_name']
    classes_to_ignore = params['classes_to_ignore']+params['generic_class']

    # save notes file
    params['notes'] = args['notes']
    if args['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])


    # load train annotations
    train_sets = []
    train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path']))
    # params copy stores only the basename so the saved model's metadata
    # does not leak absolute local paths
    params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])]

    print('\nTrain set:')
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_train))

    params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])

    # load test annotations
    test_sets = []
    test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path']))
    params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])]

    print('\nTest set:')
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_test))

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                    shuffle=True, num_workers=params['num_workers'], pin_memory=True)

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                   shuffle=False, num_workers=params['num_workers'], pin_memory=True)

    # Peek at one training batch to discover the spectrogram height the
    # model will receive.
    inputs_train = next(iter(train_loader))
    params['ip_height'] = inputs_train['spec'].shape[2]
    print('\ntrain batch size :', inputs_train['spec'].shape)

    # Explicit check rather than assert: asserts are stripped under -O,
    # and the head-replacement below only makes sense for Net2DFast.
    if params_train['model_name'] != 'Net2DFast':
        raise ValueError("Finetuning is only supported for 'Net2DFast' models, "
                         "got '{}'".format(params_train['model_name']))
    print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n')

    # Replace the classification head so the number of output channels
    # matches the new class list (+1 for the generic/background class),
    # reusing the pretrained head's filter count, kernel size and padding.
    num_filts = model.conv_classes_op.in_channels
    k_size = model.conv_classes_op.kernel_size
    pad = model.conv_classes_op.padding
    model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1, kernel_size=k_size, padding=pad)
    model.conv_classes_op.to(params['device'])

    if args['finetune_only_last_layer']:
        print('\nOnly finetuning the final layers.\n')
        # Freeze everything except the classification and size heads.
        train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op']
        train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i]
        for name, param in model.named_parameters():
            if name in train_layers:
                param.requires_grad = True
            else:
                param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    # Scheduler is stepped per-batch, hence epochs * batches total steps.
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
    if params['train_loss'] == 'mse':
        det_criterion = losses.mse_loss
    elif params['train_loss'] == 'focal':
        det_criterion = losses.focal_loss
    else:
        # BUGFIX: previously an unrecognized loss name left det_criterion
        # undefined, causing a NameError deep inside the train loop.
        raise ValueError("Unknown train_loss '{}', expected 'mse' or 'focal'".format(params['train_loss']))

    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
                                  ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
                                  ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
                                  ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
                                  params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])

    # main train loop
    for epoch in range(0, params['num_epochs']+1):

        train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
        train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])

        if epoch % params['num_eval_epochs'] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
            test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
            test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
                                             test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
            test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
            pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)

            # save finetuned model checkpoint after each evaluation
            print('saving model to: ' + params['model_file_name'])
            op_state = {'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'params' : params}
            torch.save(op_state, params['model_file_name'])


    # save an image with associated prediction for each batch in the test set
    if not args['do_not_save_images']:
        tm.save_images_batch(model, test_loader, params)