File size: 3,778 Bytes
2502b35
9ace58a
 
 
 
2502b35
9ace58a
 
 
2502b35
 
9ace58a
 
 
 
 
 
73fd754
9ace58a
 
 
 
 
 
73fd754
 
9ace58a
73fd754
296340f
9ace58a
 
 
 
 
 
 
 
 
 
 
 
 
 
7475d7b
9ace58a
 
73fd754
 
 
 
 
 
 
 
 
9ace58a
 
73fd754
 
 
9ace58a
73fd754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ace58a
 
 
 
73fd754
296340f
9ace58a
 
 
ab92695
9ace58a
296340f
9ace58a
 
 
 
 
 
296340f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz


# Runtime configuration for the detector: start from the package defaults,
# then override the few settings this demo exposes.
# (Removed dead `args = {}` — it was immediately overwritten on the next line.)
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3  # minimum confidence for a detection to be reported
args['time_expansion_factor'] = 1  # 1 = audio is not time-expanded
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'

# Only the first `max_duration` seconds of any uploaded file are processed.
max_duration = 2.0

# Load the trained model and its preprocessing parameters once at startup.
model, params = du.load_model(args['model_path'])


# Output table shown in the UI: one row per detected echolocation call.
df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label='Predictions',
)

# Bundled example recordings, each paired with a detection threshold,
# selectable directly from the interface.
examples = [
    ['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3],
]


def make_prediction(file_name=None, detection_threshold=0.3):
    """Run the bat call detector on one audio file.

    Parameters
    ----------
    file_name : str or None
        Path to the input audio file (from the Gradio Audio widget).
    detection_threshold : float or str
        Minimum detection confidence; the Dropdown may deliver it as a
        string, None, or ''.

    Returns
    -------
    list
        ``[pandas.DataFrame, numpy.ndarray]`` — the predictions table and the
        annotated spectrogram image — or an error string if no file was given.
    """
    # Guard clause: without a file there is nothing to process.
    if file_name is None:
        return "You must provide an input audio file."
    audio_file = file_name

    # The Dropdown can hand back None or '' — only override on a real value.
    # NOTE(review): this mutates the module-level `args`, so the chosen
    # threshold persists across requests (pre-existing behaviour, kept).
    if detection_threshold is not None and detection_threshold != '':
        args['detection_threshold'] = float(detection_threshold)

    # Run the detector; only the first `max_duration` seconds are analysed.
    results = du.process_file(audio_file, model, params, args, max_duration=max_duration)

    # One column per annotation field, built in a single set of passes
    # (the original made a redundant list copy of the annotations first).
    anns = results['pred_dict']['annotation']
    data = {
        'species': [aa['class'] for aa in anns],
        'time': [aa['start_time'] for aa in anns],
        'detection_prob': [aa['det_prob'] for aa in anns],
        'species_prob': [aa['class_prob'] for aa in anns],
    }

    df = pd.DataFrame(data=data)
    im = generate_results_image(audio_file, anns)

    return [df, im]


def generate_results_image(audio_file, anns):
    """Render the recording's spectrogram with predicted call boxes overlaid.

    Parameters
    ----------
    audio_file : str
        Path to the audio file to visualise.
    anns : list of dict
        Annotation dicts as produced by ``du.process_file``.

    Returns
    -------
    numpy.ndarray
        The rendered figure as an (H, W, channels) uint8 image array.
    """
    # Load audio at the model's target sample rate, capped at max_duration.
    # (Removed an unused `duration` local that was computed here.)
    sampling_rate, audio = au.load_audio_file(
        audio_file, args['time_expansion_factor'],
        params['target_samp_rate'], params['scale_raw_audio'],
        max_duration=max_duration)

    # Spectrogram used for visualisation.
    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)

    # One pixel per spectrogram bin: figsize is shape/100 at dpi=100.
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
                     dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(
        spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration,
                         params, spec.max() * 1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.tight_layout()

    # Rasterise the figure into a numpy image array.
    # NOTE(review): `tostring_rgb` is deprecated and removed in matplotlib
    # >= 3.10 — if upgrading, switch to
    # `np.asarray(fig.canvas.buffer_rgba())[:, :, :3]`. Kept as-is here since
    # the pinned matplotlib version is unknown.
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))

    return im


# Description text rendered beneath the title in the web UI
# (implicit string concatenation instead of backslash continuations).
descr_txt = (
    "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
    "<br>This model is only trained on bat species from the UK. If the input "
    "file is longer than 2 seconds, only the first 2 seconds will be processed."
    "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
)

# Build and launch the Gradio app: audio upload + threshold dropdown in,
# predictions table + annotated spectrogram out.
gr.Interface(
    fn=make_prediction,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath"),
        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
    ],
    outputs=[df, gr.Image(label="Visualisation")],
    theme="huggingface",
    title="BatDetect2 Demo",
    description=descr_txt,
    examples=examples,
    allow_flagging='never',
).launch()