Spaces:
Build error
Build error
File size: 5,978 Bytes
7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 7d40d1a b0f3e42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import matplotlib.pyplot as plt
import time
import numpy as np
import pandas as pd
from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
from portiloop.src.detection import SleepSpindleRealTimeDetector
from portiloop.src.stimulation import UpStateDelayer
plt.switch_backend('agg')
from portiloop.src.processing import FilterPipeline
def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
# Read the csv file to a numpy array
data_whole = np.loadtxt(csv_file.name, delimiter=',')
# Get the data from the selected channel
detect_channel = int(detect_channel)
freq = int(freq)
data = data_whole[:, detect_channel - 1]
# Create the detector and the stimulator
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
stimulator = DemoSleepSpindleRealTimeStimulator()
if spindle_detection_mode != 'Fast':
delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
stimulator.add_delayer(delayer)
# Create the filtering pipeline
if filtering:
filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
# Plotting variables
points = []
activations = []
delayed_activations = []
# Go through the data
for index, point in enumerate(data):
# Step the delayer if exists
if spindle_detection_mode != 'Fast':
delayed = delayer.step(point)
if delayed:
delayed_activations.append(1)
else:
delayed_activations.append(0)
# Filter the data
if filtering:
filtered_point = filter.filter(np.array([point]))
else:
filtered_point = point
filtered_point = filtered_point.tolist()
# Detect the spindles
result = detector.detect([filtered_point])
# Stimulate if necessary
stim = stimulator.stimulate(result)
if stim:
activations.append(1)
else:
activations.append(0)
# Add data to plotting buffer
points.append(filtered_point[0])
# Function to return a list of all indexes where activations have happened
def get_activations(activations):
return [i for i, x in enumerate(activations) if x == 1]
# Plot the data
if index % (10 * freq) == 0 and index >= (10 * freq):
plt.close()
fig = plt.figure(figsize=(20, 10))
plt.clf()
plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
# Draw vertical lines for activations
for index in get_activations(activations[-10 * freq:]):
plt.axvline(x=index / freq, color='r', label="Fast Stimulation")
if spindle_detection_mode != 'Fast':
for index in get_activations(delayed_activations[-10 * freq:]):
plt.axvline(x=index / freq, color='g', label="Delayed Stimulation")
# Add axis titles and legend
plt.legend()
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
yield fig, None
# Put all points and activations back in numpy arrays
points = np.array(points)
activations = np.array(activations)
delayed_activations = np.array(delayed_activations)
# Concatenate with the original data
data_whole = np.concatenate((data_whole, points.reshape(-1, 1), activations.reshape(-1, 1), delayed_activations.reshape(-1, 1)), axis=1)
# Output the data to a csv file
np.savetxt('output.csv', data_whole, delimiter=',')
yield None, "output.csv"
with gr.Blocks() as demo:
gr.Markdown("# Portiloop Demo")
gr.Markdown("This Demo takes as input a csv file containing EEG data and outputs a csv file with the following added: \n * The data filtered by the Portiloop online filter \n * The stimulations made by Portiloop.")
gr.Markdown("Upload your CSV file and click **Run Inference** to start the processing...")
# Row containing all inputs:
with gr.Row():
# CSV file
csv_file = gr.UploadButton(label="CSV File", file_count="single")
# Filtering (Boolean)
filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
# Threshold value
threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
# Detection Channel
detect_column = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], value="1", label="Detection Column in CSV", interactive=True)
# Frequency
freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
# Spindle Frequency
spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency (Hz)", interactive=True)
# Spindle Detection Mode
spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
# Time to buffer
time_to_buffer = gr.Slider(0, 1, value=0.3, step=0.01, label="Time to Buffer (s)", interactive=True)
# Output plot
output_plot = gr.Plot()
# Output file
output_array = gr.File(label="Output CSV File")
# Row containing all buttons:
with gr.Row():
# Run inference button
run_inference = gr.Button(value="Run Inference")
# Reset button
reset = gr.Button(value="Reset", variant="secondary")
run_inference.click(fn=do_treatment, inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer], outputs=[output_plot, output_array])
demo.queue()
demo.launch(share=True)
|