import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
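
# BatDetect2 utilities: detector inference (du), audio loading (au), and plotting (viz)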
import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz

# set up the inference arguments
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
max_duration = 2.0  # only the first 2 seconds of each file are processed

# load the model
model, params = du.load_model(args['model_path'])
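
# output table component: one row per detected call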
df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label="Predictions",
)
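
# example clips bundled with the demo, each paired with a 0.3 detection threshold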
examples = [
    ['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
    ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3],
]

def make_prediction(file_name=None, detection_threshold=0.3):
    if file_name is not None:
        audio_file = file_name
    else:
        return "You must provide an input audio file."

    if detection_threshold is not None and detection_threshold != '':
        args['detection_threshold'] = float(detection_threshold)

    # run the detector over (at most the first max_duration seconds of) the file
    results = du.process_file(audio_file, model, params, args, max_duration=max_duration)

    # unpack the per-call annotations into columns for the results table
    anns = results['pred_dict']['annotation']
    clss = [aa['class'] for aa in anns]
    st_time = [aa['start_time'] for aa in anns]
    cls_prob = [aa['class_prob'] for aa in anns]
    det_prob = [aa['det_prob'] for aa in anns]
    data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
    pred_df = pd.DataFrame(data=data)

    # spectrogram image with the detection boxes overlaid
    im = generate_results_image(audio_file, anns)
    return [pred_df, im]

def generate_results_image(audio_file, anns):
    # load the audio (truncated to max_duration seconds)
    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
                                              params['target_samp_rate'], params['scale_raw_audio'],
                                              max_duration=max_duration)

    # generate the spectrogram
    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)

    # create a figure sized so that one spectrogram bin maps to one pixel
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max() * 1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.tight_layout()

    # convert the matplotlib figure into a numpy image array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
    return im

descr_txt = "Demo of BatDetect2, a deep learning model for detecting and classifying bat echolocation calls. " \
            "<br>The model is trained only on bat species from the UK. If the input " \
            "file is longer than 2 seconds, only the first 2 seconds are processed." \
            "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."

gr.Interface(
    fn=make_prediction,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath", label="Audio file"),
        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], value=0.3, label="Detection threshold"),
    ],
    outputs=[df, gr.Image(label="Visualisation")],
    theme="huggingface",
    title="BatDetect2 Demo",
    description=descr_txt,
    examples=examples,
    allow_flagging="never",
).launch()
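
# to run locally: `python app.py` (assumes the bat_detect package, the example_data
# audio clips, and the model checkpoint referenced above sit alongside this script)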