# Spaces: Sleeping  (page-status text captured with this file; not part of the program)
import gradio as gr | |
import matplotlib.pyplot as plt | |
import time | |
import numpy as np | |
import pandas as pd | |
from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator | |
from portiloop.src.detection import SleepSpindleRealTimeDetector | |
from portiloop.src.stimulation import UpStateDelayer | |
plt.switch_backend('agg') | |
from portiloop.src.processing import FilterPipeline | |
def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
    """Replay one CSV column through the spindle detector, streaming plots.

    This is a generator. While it walks through the samples it periodically
    yields ``(figure, None)``; once every sample has been processed it writes
    the results to ``output.csv`` and yields ``(None, "output.csv")``.

    Args:
        csv_file: Uploaded file object; only its ``name`` attribute (a path)
            is read.
        filtering: When truthy, each sample is passed through a
            ``FilterPipeline`` before detection.
        threshold: Forwarded to ``SleepSpindleRealTimeDetector``.
        detect_channel: 1-based index (int or numeric string) of the CSV
            column to analyse.
        freq: Sampling rate (int or numeric string). It is given to the
            filter pipeline and the delayer, and sets the plotting window
            to ``10 * freq`` samples.
        spindle_freq: Forwarded to ``UpStateDelayer``.
        spindle_detection_mode: ``'Fast'`` runs without a delayer. Any other
            value creates an ``UpStateDelayer`` whose third positional
            argument is ``True`` only for ``'Peak'``.
        time_to_buffer: Forwarded to ``UpStateDelayer``.

    Yields:
        ``(fig, None)`` every ``10 * freq`` samples, then a final
        ``(None, "output.csv")``.
    """
    # Read the csv file to a numpy array. ndmin=2 keeps a single-column (or
    # single-row) file two-dimensional, so the column indexing below and the
    # final concatenate keep working.
    data_whole = np.loadtxt(csv_file.name, delimiter=',', ndmin=2)

    # Get the data from the selected channel
    detect_channel = int(detect_channel)
    freq = int(freq)
    data = data_whole[:, detect_channel - 1]

    # Create the detector and the stimulator
    detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)  # always 1 because we have only one channel
    stimulator = DemoSleepSpindleRealTimeStimulator()
    use_delayer = spindle_detection_mode != 'Fast'
    delayer = None
    if use_delayer:
        delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
        stimulator.add_delayer(delayer)

    # Create the filtering pipeline (named so it does not shadow the builtin `filter`)
    pipeline = FilterPipeline(nb_channels=1, sampling_rate=freq) if filtering else None

    # Plotting variables
    window = 10 * freq
    points = []
    activations = []
    delayed_activations = []

    # Go through the data
    for index, point in enumerate(data):
        # Step the delayer if it exists. Without a delayer a 0 is still
        # recorded so this column always has one entry per sample; an empty
        # column would make the final concatenate fail in 'Fast' mode.
        if use_delayer:
            delayed_activations.append(1 if delayer.step(point) else 0)
        else:
            delayed_activations.append(0)

        # Wrap the sample in a one-element array in both branches. The
        # unfiltered path used to pass the bare scalar on, which broke the
        # `filtered_point[0]` lookup below.
        sample = np.array([point])
        if pipeline is not None:
            sample = pipeline.filter(sample)
        filtered_point = sample.tolist()

        # Detect the spindles
        result = detector.detect([filtered_point])

        # Stimulate if necessary
        activations.append(1 if stimulator.stimulate(result) else 0)

        # Add data to plotting buffer
        points.append(filtered_point[0])

        # Plot the most recent window of data
        if index % window == 0:
            # Drop the previously yielded figure before creating a new one
            plt.close()
            fig = plt.figure(figsize=(20, 10))
            plt.plot(points[-window:], label="Data")
            # Draw vertical lines for activations
            for stim_x in get_activations(activations[-window:]):
                plt.axvline(x=stim_x, color='r', label="Fast Stimulation")
            if use_delayer:
                for stim_x in get_activations(delayed_activations[-window:]):
                    plt.axvline(x=stim_x, color='g', label="Delayed Stimulation")
            yield fig, None

    # Put all points and activations back in numpy arrays
    points = np.array(points)
    activations = np.array(activations)
    delayed_activations = np.array(delayed_activations)

    # Append the processed signal and both activation columns to the original data
    data_whole = np.concatenate(
        (
            data_whole,
            points.reshape(-1, 1),
            activations.reshape(-1, 1),
            delayed_activations.reshape(-1, 1),
        ),
        axis=1,
    )

    # Output the data to a csv file
    np.savetxt('output.csv', data_whole, delimiter=',')
    yield None, "output.csv"
def get_activations(activations):
    """Return the list of positions in *activations* whose value equals 1."""
    hit_positions = []
    for position, flag in enumerate(activations):
        if flag == 1:
            hit_positions.append(position)
    return hit_positions
# Build the demo page: one row of inputs, the two outputs, then the buttons.
with gr.Blocks() as demo:
    gr.Markdown("Enter your csv file and click **Run Inference** to get the output.")

    # Row containing all inputs:
    with gr.Row():
        # CSV file
        csv_file = gr.UploadButton(label="CSV File", file_count="single")
        # Filtering (Boolean)
        filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
        # Threshold value
        threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
        # Detection Channel
        detect_column = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], value="1", label="Detection Column", interactive=True)
        # Frequency
        freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Frequency", interactive=True)
        # Spindle Frequency
        spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency", interactive=True)
        # Spindle Detection Mode
        spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
        # Time to buffer
        time_to_buffer = gr.Slider(0, 1, value=0, step=0.01, label="Time to Buffer", interactive=True)

    # Output plot
    output_plot = gr.Plot()
    # Output file
    output_array = gr.File(label="Output CSV File")

    # Row containing all buttons:
    with gr.Row():
        # Run inference button
        run_inference = gr.Button(value="Run Inference")
        # Reset button
        reset = gr.Button(value="Reset", variant="secondary")

    # Keep a handle on the inference event so Reset can cancel a run in progress.
    run_event = run_inference.click(
        fn=do_treatment,
        inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer],
        outputs=[output_plot, output_array],
    )

    # Reset stops any running inference and empties both outputs. None is the
    # same "nothing to show" value do_treatment yields for the output it is
    # not updating. Input widgets are left as the user set them.
    reset.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[output_plot, output_array],
        cancels=[run_event],
    )

# The queue is enabled before launching; generator handlers such as
# do_treatment and event cancellation depend on it.
demo.queue()
demo.launch()