File size: 3,790 Bytes
2502b35
9ace58a
 
 
 
2502b35
9ace58a
 
 
2502b35
 
9ace58a
 
 
 
 
 
73fd754
9ace58a
 
 
 
 
 
73fd754
 
9ace58a
73fd754
296340f
9ace58a
 
 
 
 
 
 
 
 
 
 
 
 
 
7475d7b
9ace58a
 
73fd754
 
 
 
 
 
 
 
 
9ace58a
 
73fd754
 
 
9ace58a
73fd754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ace58a
 
 
 
73fd754
296340f
9ace58a
 
 
 
 
296340f
9ace58a
 
 
 
 
 
296340f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz


# Detector configuration: start from the library defaults and override the
# settings this demo relies on. `args` is read by make_prediction and
# generate_results_image below.
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'

# Passed as `max_duration` to both processing and audio loading below; the
# demo description states that only the first 2 seconds are processed.
max_duration = 2.0

# Load the model once at import time so every request reuses it.
model, params = du.load_model(args['model_path'])


# Output table for the predictions; the headers match the column names of
# the DataFrame returned by make_prediction.
df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label='Predictions',
)

# Bundled example recordings, each paired with a detection threshold.
examples = [
    ['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3],
]


def make_prediction(file_name=None, detection_threshold=0.3):
    """Run the detector on one audio file and build the demo outputs.

    Args:
        file_name: Path to the uploaded audio file, or None if nothing was
            uploaded.
        detection_threshold: Threshold selected in the UI (a number or a
            numeric string). None or '' falls back to the value already in
            the module-level ``args``.

    Returns:
        A two-item list ``[predictions, image]``. ``predictions`` is a
        DataFrame with columns species/time/detection_prob/species_prob, one
        row per annotation returned by the detector. ``image`` is the
        visualisation array from generate_results_image. When no file is
        given, an empty DataFrame with the same columns and None are
        returned, so the result still matches the interface's two outputs.
    """
    columns = ['species', 'time', 'detection_prob', 'species_prob']

    if file_name is None:
        # The interface declares two outputs (table + image); a single
        # message string would not match them, so return empty outputs.
        return [pd.DataFrame(columns=columns), None]

    # Use a per-call copy so a threshold chosen in one request does not leak
    # into later requests through the shared module-level `args`.
    run_args = dict(args)
    if detection_threshold is not None and detection_threshold != '':
        run_args['detection_threshold'] = float(detection_threshold)

    # process the file to generate predictions
    results = du.process_file(file_name, model, params, run_args,
                              max_duration=max_duration)

    anns = list(results['pred_dict']['annotation'])
    data = {
        'species': [ann['class'] for ann in anns],
        'time': [ann['start_time'] for ann in anns],
        'detection_prob': [ann['det_prob'] for ann in anns],
        'species_prob': [ann['class_prob'] for ann in anns],
    }
    predictions = pd.DataFrame(data=data, columns=columns)
    im = generate_results_image(file_name, anns)

    return [predictions, im]


def generate_results_image(audio_file, anns):
    """Render the spectrogram of `audio_file` with `anns` drawn as boxes.

    Args:
        audio_file: Path to the audio file to visualise.
        anns: Annotation dicts produced by the detector (the same list
            make_prediction builds its table from).

    Returns:
        A uint8 NumPy array of shape (height, width, 3) holding the RGB
        pixels of the rendered figure.
    """
    # load audio (truncated via max_duration, as in make_prediction)
    sampling_rate, audio = au.load_audio_file(
        audio_file, args['time_expansion_factor'], params['target_samp_rate'],
        params['scale_raw_audio'], max_duration=max_duration)

    # generate spec; the second returned value is not needed here
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True, False)

    # create fig, closing anything left over before reusing figure number 1
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
                     dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate,
                                        params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration,
                         params, spec.max() * 1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.tight_layout()

    # convert fig to image.
    # buffer_rgba() replaces tostring_rgb(), which Matplotlib 3.8 deprecates
    # and 3.10 removes. Its array is already (h, w, 4) in device pixels, so
    # there is no reshape from the logical size of get_width_height(), which
    # is wrong on HiDPI canvases. Copy so the pixels outlive the figure.
    fig.canvas.draw()
    im = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)

    return im


# Description shown on the demo page. It mixes `<br>` tags and a Markdown
# link. The processing limit is taken from `max_duration`, so the text cannot
# drift from the value actually passed to the detector (2.0 renders as "2").
descr_txt = (
    "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
    "<br>This model is only trained on bat species from the UK. If the input "
    f"file is longer than {max_duration:g} seconds, only the first "
    f"{max_duration:g} seconds will be processed."
    "<br>Check out the paper "
    "[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
)

# Build and launch the demo UI. The audio upload and threshold dropdown are
# passed to make_prediction; its two return values fill the predictions table
# (`df`) and the visualisation image.
demo = gr.Interface(
    fn=make_prediction,
    inputs=[
        gr.Audio(source="upload", type="filepath", optional=True),
        # Label the threshold selector so users know what it controls.
        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                    label="Detection threshold"),
    ],
    outputs=[df, gr.Image(label="Visualisation")],
    theme="huggingface",
    title="BatDetect2 Demo",
    description=descr_txt,
    examples=examples,
    allow_flagging='never',
)
demo.launch()