# Source: Hugging Face file view (raw | history | blame), 5.44 kB
# Author: Milo Sobral
# Commit 7d40d1a: "Finished setting up the gradio demo"
import gradio as gr
import matplotlib.pyplot as plt
import time
import numpy as np
import pandas as pd
from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
from portiloop.src.detection import SleepSpindleRealTimeDetector
from portiloop.src.stimulation import UpStateDelayer
plt.switch_backend('agg')
from portiloop.src.processing import FilterPipeline
def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
    """Replay a CSV recording through the spindle detector, streaming plots.

    This is a generator meant to be used as a Gradio event handler. It feeds
    the selected column of the CSV to the detector one sample at a time,
    periodically yields a figure of the most recent window, and finally
    yields the path of a CSV file containing the results.

    Args:
        csv_file: Uploaded file object; only its ``name`` attribute (the path
            on disk) is read.
        filtering: When truthy, every sample goes through a ``FilterPipeline``
            before detection; otherwise the raw sample is used.
        threshold: Detection threshold forwarded to the detector.
        detect_channel: 1-based index of the CSV column to analyse.
        freq: Sampling frequency; also sets the plot window (``10 * freq``
            samples) and the plotting period.
        spindle_freq: Spindle frequency forwarded to ``UpStateDelayer``.
        spindle_detection_mode: ``'Fast'`` disables the delayer. Any other
            value enables it, with the peak flag set only for ``'Peak'``.
        time_to_buffer: Buffering time forwarded to ``UpStateDelayer``.

    Yields:
        ``(figure, None)`` every ``10 * freq`` samples, then
        ``(None, "output.csv")`` once the whole recording has been processed.

    Raises:
        ValueError: If no CSV file was provided.
    """
    if csv_file is None:
        raise ValueError("No CSV file was uploaded; upload a CSV file before running inference.")

    detect_channel = int(detect_channel)
    freq = int(freq)
    use_delayer = spindle_detection_mode != 'Fast'
    window = 10 * freq

    # ndmin=2 keeps single-column files two-dimensional so that the column
    # selection and the final column-wise concatenation both work.
    data_whole = np.loadtxt(csv_file.name, delimiter=',', ndmin=2)
    # Get the data from the selected channel (the UI numbers columns from 1)
    data = data_whole[:, detect_channel - 1]

    # Create the detector and the stimulator
    detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)  # always 1 because we have only one channel
    stimulator = DemoSleepSpindleRealTimeStimulator()
    delayer = None
    if use_delayer:
        delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
        stimulator.add_delayer(delayer)

    # Create the filtering pipeline
    filter_pipeline = FilterPipeline(nb_channels=1, sampling_rate=freq) if filtering else None

    # Plotting / output buffers, one entry per processed sample
    points = []
    activations = []
    delayed_activations = []

    for index, point in enumerate(data):
        # Step the delayer if it exists. In 'Fast' mode a 0 is still recorded
        # so this column has the same length as the others when the output
        # file is assembled.
        if use_delayer:
            delayed_activations.append(1 if delayer.step(point) else 0)
        else:
            delayed_activations.append(0)

        # Always wrap the sample in a 1-element array so the filtered and the
        # unfiltered paths hand the same list shape to the detector.
        sample = np.array([point])
        if filtering:
            sample = filter_pipeline.filter(sample)
        filtered_point = sample.tolist()

        # Detect the spindles
        result = detector.detect([filtered_point])

        # Stimulate if necessary
        activations.append(1 if stimulator.stimulate(result) else 0)

        # Add data to plotting buffer
        points.append(filtered_point[0])

        # Plot the most recent window every `window` samples
        if index % window == 0:
            plt.close()
            fig = plt.figure(figsize=(20, 10))
            plt.plot(points[-window:], label="Data")
            # Draw vertical lines for activations
            for stim_index in get_activations(activations[-window:]):
                plt.axvline(x=stim_index, color='r', label="Fast Stimulation")
            if use_delayer:
                for stim_index in get_activations(delayed_activations[-window:]):
                    plt.axvline(x=stim_index, color='g', label="Delayed Stimulation")
            yield fig, None

    # Append the processed signal and both activation traces as new columns
    # after the original data.
    data_whole = np.concatenate(
        (
            data_whole,
            np.array(points).reshape(-1, 1),
            np.array(activations).reshape(-1, 1),
            np.array(delayed_activations).reshape(-1, 1),
        ),
        axis=1,
    )

    # NOTE(review): the output path is fixed and relative to the working
    # directory, so concurrent runs would overwrite each other's file.
    np.savetxt('output.csv', data_whole, delimiter=',')
    yield None, "output.csv"
# Helper used when plotting: locate the samples at which a stimulation fired
def get_activations(activations):
    """Return the positions in ``activations`` whose value equals 1."""
    hits = []
    for position, flag in enumerate(activations):
        if flag == 1:
            hits.append(position)
    return hits
# Gradio UI: input controls, output widgets and the buttons driving them.
# NOTE(review): the nesting of the layout blocks below follows the
# "Row containing ..." comments — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("Enter your csv file and click **Run Inference** to get the output.")

    # Row containing all inputs:
    with gr.Row():
        # CSV file
        csv_file = gr.UploadButton(label="CSV File", file_count="single")
        # Filtering (Boolean)
        filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
        # Threshold value
        threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
        # Detection Channel
        detect_column = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], value="1", label="Detection Column", interactive=True)
        # Frequency
        freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Frequency", interactive=True)
        # Spindle Frequency
        spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency", interactive=True)
        # Spindle Detection Mode
        spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
        # Time to buffer
        time_to_buffer = gr.Slider(0, 1, value=0, step=0.01, label="Time to Buffer", interactive=True)

    # Output plot
    output_plot = gr.Plot()
    # Output file
    output_array = gr.File(label="Output CSV File")

    # Row containing all buttons:
    with gr.Row():
        # Run inference button
        run_inference = gr.Button(value="Run Inference")
        # Reset button
        reset = gr.Button(value="Reset", variant="secondary")

    # Keep a handle on the inference event so that Reset can cancel it.
    run_event = run_inference.click(
        fn=do_treatment,
        inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer],
        outputs=[output_plot, output_array],
    )
    # Reset stops a running inference and clears both outputs. Without the
    # cancellation, the generator's next yield would repopulate the plot.
    reset.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[output_plot, output_array],
        cancels=[run_event],
    )

# The queue is needed for generator handlers (streamed plots) and for
# cancelling a running event.
demo.queue()
demo.launch()