Spaces:
Sleeping
Sleeping
Milo Sobral
commited on
Commit
·
986653d
1
Parent(s):
d4bce82
Added the Phase demo
Browse files- portiloop/src/demo/offline.py +20 -3
- portiloop/src/demo/phase_demo.py +63 -0
- portiloop/src/demo/test_offline.py +9 -7
- portiloop/src/demo/utils.py +15 -9
- portiloop/src/stimulation.py +54 -6
portiloop/src/demo/offline.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
import numpy as np
|
3 |
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
4 |
-
|
5 |
from portiloop.src.processing import FilterPipeline
|
6 |
from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
|
7 |
import gradio as gr
|
8 |
|
9 |
|
10 |
-
def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
11 |
# Get the options from the checkbox group
|
12 |
offline_filtering = 0 in detect_filter_opts
|
13 |
lacourse = 1 in detect_filter_opts
|
@@ -72,12 +71,17 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
72 |
if online_detection:
|
73 |
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
74 |
stimulator = OfflineSleepSpindleRealTimeStimulator()
|
|
|
|
|
|
|
|
|
75 |
|
76 |
if online_filtering or online_detection:
|
77 |
print("Running online filtering and detection...")
|
78 |
|
79 |
points = []
|
80 |
online_activations = []
|
|
|
81 |
|
82 |
# Go through the data
|
83 |
for index, point in enumerate(data):
|
@@ -93,6 +97,13 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
93 |
# Detect the spindles
|
94 |
result = detector.detect([filtered_point])
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# Stimulate if necessary
|
97 |
stim = stimulator.stimulate(result)
|
98 |
if stim:
|
@@ -112,6 +123,12 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
112 |
data_whole = np.concatenate((data_whole, online_activations), axis=1)
|
113 |
columns.append("online_stimulations")
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
print("Saving output...")
|
116 |
# Output the data to a csv file
|
117 |
np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
|
|
|
|
|
1 |
import numpy as np
|
2 |
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
3 |
+
from portiloop.src.stimulation import UpStateDelayer
|
4 |
from portiloop.src.processing import FilterPipeline
|
5 |
from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
|
6 |
import gradio as gr
|
7 |
|
8 |
|
9 |
+
def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
|
10 |
# Get the options from the checkbox group
|
11 |
offline_filtering = 0 in detect_filter_opts
|
12 |
lacourse = 1 in detect_filter_opts
|
|
|
71 |
if online_detection:
|
72 |
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
73 |
stimulator = OfflineSleepSpindleRealTimeStimulator()
|
74 |
+
if stimulation_phase != "Fast":
|
75 |
+
stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
|
76 |
+
stimulator.add_delayer(stimulation_delayer)
|
77 |
+
|
78 |
|
79 |
if online_filtering or online_detection:
|
80 |
print("Running online filtering and detection...")
|
81 |
|
82 |
points = []
|
83 |
online_activations = []
|
84 |
+
delayed_stims = []
|
85 |
|
86 |
# Go through the data
|
87 |
for index, point in enumerate(data):
|
|
|
97 |
# Detect the spindles
|
98 |
result = detector.detect([filtered_point])
|
99 |
|
100 |
+
if stimulation_phase != "Fast":
|
101 |
+
delayed_stim = stimulation_delayer.step_timesteps(filtered_point[0])
|
102 |
+
if delayed_stim:
|
103 |
+
delayed_stims.append(1)
|
104 |
+
else:
|
105 |
+
delayed_stims.append(0)
|
106 |
+
|
107 |
# Stimulate if necessary
|
108 |
stim = stimulator.stimulate(result)
|
109 |
if stim:
|
|
|
123 |
data_whole = np.concatenate((data_whole, online_activations), axis=1)
|
124 |
columns.append("online_stimulations")
|
125 |
|
126 |
+
if stimulation_phase != "Fast":
|
127 |
+
delayed_stims = np.array(delayed_stims)
|
128 |
+
delayed_stims = np.expand_dims(delayed_stims, axis=1)
|
129 |
+
data_whole = np.concatenate((data_whole, delayed_stims), axis=1)
|
130 |
+
columns.append("delayed_stimulations")
|
131 |
+
|
132 |
print("Saving output...")
|
133 |
# Output the data to a csv file
|
134 |
np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
|
portiloop/src/demo/phase_demo.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from portiloop.src.demo.offline import run_offline
|
4 |
+
|
5 |
+
|
6 |
+
def on_upload_file(file):
|
7 |
+
# Check if file extension is .xdf
|
8 |
+
if file.name.split(".")[-1] != "xdf":
|
9 |
+
raise gr.Error("Please upload a .xdf file.")
|
10 |
+
else:
|
11 |
+
return file.name
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
with gr.Blocks(title="Portiloop") as demo:
|
16 |
+
gr.Markdown("# Portiloop Demo")
|
17 |
+
gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
|
18 |
+
gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
|
19 |
+
|
20 |
+
with gr.Row():
|
21 |
+
xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
|
22 |
+
xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
|
23 |
+
|
24 |
+
xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
|
25 |
+
|
26 |
+
# Make a checkbox group for the options
|
27 |
+
detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
|
28 |
+
|
29 |
+
# Options for phase stimulation
|
30 |
+
with gr.Row():
|
31 |
+
# Dropwdown for phase
|
32 |
+
phase = gr.Dropdown(choices=["Peak", "Fast", "Valley"], value="Peak", label="Phase", interactive=True)
|
33 |
+
buffer_time = gr.Slider(0, 1, value=0.3, step=0.01, label="Buffer Time", interactive=True)
|
34 |
+
|
35 |
+
# Threshold value
|
36 |
+
threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
|
37 |
+
# Detection Channel
|
38 |
+
detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
|
39 |
+
# Frequency
|
40 |
+
freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
|
41 |
+
|
42 |
+
with gr.Row():
|
43 |
+
output_array = gr.File(label="Output CSV File")
|
44 |
+
output_table = gr.Markdown(label="Output Table")
|
45 |
+
|
46 |
+
run_inference = gr.Button(value="Run Inference")
|
47 |
+
run_inference.click(
|
48 |
+
fn=run_offline,
|
49 |
+
inputs=[
|
50 |
+
xdf_file_static,
|
51 |
+
detect_filter,
|
52 |
+
threshold,
|
53 |
+
detect_channel,
|
54 |
+
freq,
|
55 |
+
phase,
|
56 |
+
buffer_time],
|
57 |
+
outputs=[output_array, output_table])
|
58 |
+
|
59 |
+
demo.queue()
|
60 |
+
demo.launch(share=False)
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
main()
|
portiloop/src/demo/test_offline.py
CHANGED
@@ -2,7 +2,7 @@ import itertools
|
|
2 |
import unittest
|
3 |
from portiloop.src.demo.offline import run_offline
|
4 |
from pathlib import Path
|
5 |
-
|
6 |
|
7 |
class TestOffline(unittest.TestCase):
|
8 |
|
@@ -30,16 +30,18 @@ class TestOffline(unittest.TestCase):
|
|
30 |
self.assertTrue(config['online_filtering'])
|
31 |
|
32 |
def test_single_option(self):
|
|
|
|
|
|
|
|
|
33 |
res = list(run_offline(
|
34 |
self.xdf_file,
|
35 |
-
|
36 |
-
online_filtering=True,
|
37 |
-
online_detection=True,
|
38 |
-
wamsley=True,
|
39 |
-
lacourse=True,
|
40 |
threshold=0.5,
|
41 |
channel_num=2,
|
42 |
-
freq=250
|
|
|
|
|
43 |
print(res)
|
44 |
|
45 |
def tearDown(self):
|
|
|
2 |
import unittest
|
3 |
from portiloop.src.demo.offline import run_offline
|
4 |
from pathlib import Path
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
|
7 |
class TestOffline(unittest.TestCase):
|
8 |
|
|
|
30 |
self.assertTrue(config['online_filtering'])
|
31 |
|
32 |
def test_single_option(self):
|
33 |
+
|
34 |
+
# Test options correspond to an index in the possible checkbox group options
|
35 |
+
test_options = [3, 4]
|
36 |
+
|
37 |
res = list(run_offline(
|
38 |
self.xdf_file,
|
39 |
+
test_options,
|
|
|
|
|
|
|
|
|
40 |
threshold=0.5,
|
41 |
channel_num=2,
|
42 |
+
freq=250,
|
43 |
+
stimulation_phase="Peak",
|
44 |
+
buffer_time=0.3))
|
45 |
print(res)
|
46 |
|
47 |
def tearDown(self):
|
portiloop/src/demo/utils.py
CHANGED
@@ -134,18 +134,24 @@ def offline_filter(signal, freq):
|
|
134 |
def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
|
135 |
# Count the number of spindles detected by each method
|
136 |
online_stimulation_count = np.sum(online_stimulation)
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
144 |
|
145 |
# Create markdown table with the results
|
146 |
table = "| Method | Detected spindles | Overlap with Portiloop |\n"
|
147 |
table += "| --- | --- | --- |\n"
|
148 |
table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
|
149 |
-
|
150 |
-
|
|
|
|
|
151 |
return table
|
|
|
134 |
def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
|
135 |
# Count the number of spindles detected by each method
|
136 |
online_stimulation_count = np.sum(online_stimulation)
|
137 |
+
if lacourse_spindles is not None:
|
138 |
+
lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
|
139 |
+
# Count how many spindles were detected by both online and lacourse
|
140 |
+
both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
|
141 |
+
|
142 |
+
if wamsley_spindles is not None:
|
143 |
+
wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
|
144 |
+
# Count how many spindles were detected by both online and wamsley
|
145 |
+
both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
|
146 |
+
|
147 |
+
|
148 |
|
149 |
# Create markdown table with the results
|
150 |
table = "| Method | Detected spindles | Overlap with Portiloop |\n"
|
151 |
table += "| --- | --- | --- |\n"
|
152 |
table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
|
153 |
+
if lacourse_spindles is not None:
|
154 |
+
table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
|
155 |
+
if wamsley_spindles is not None:
|
156 |
+
table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
|
157 |
return table
|
portiloop/src/stimulation.py
CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
|
|
3 |
import time
|
4 |
from threading import Thread, Lock
|
5 |
from pathlib import Path
|
|
|
|
|
6 |
|
7 |
from portiloop.src import ADS
|
8 |
|
@@ -146,20 +148,18 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
146 |
|
147 |
# Class that delays stimulation to always stimulate peak or through
|
148 |
class UpStateDelayer:
|
149 |
-
def __init__(self, sample_freq,
|
150 |
'''
|
151 |
args:
|
152 |
sample_freq: int -> Sampling frequency of signal in Hz
|
153 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
154 |
'''
|
155 |
# Get number of timesteps for a whole spindle
|
156 |
-
self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
|
157 |
self.sample_freq = sample_freq
|
158 |
-
self.buffer_size = 1.5 * self.spindle_timesteps
|
159 |
self.peak = peak
|
160 |
self.buffer = []
|
161 |
self.time_to_buffer = time_to_buffer
|
162 |
-
self.stimulate =
|
163 |
|
164 |
self.state = States.NO_SPINDLE
|
165 |
|
@@ -192,10 +192,39 @@ class UpStateDelayer:
|
|
192 |
return True
|
193 |
return False
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
def detected(self):
|
196 |
if self.state == States.NO_SPINDLE:
|
197 |
self.state = States.BUFFERING
|
198 |
-
self.time_started = time.time()
|
199 |
|
200 |
def compute_time_to_wait(self):
|
201 |
"""
|
@@ -208,8 +237,27 @@ class UpStateDelayer:
|
|
208 |
# Returns the index of the last peak in the buffer
|
209 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
# Compute the time until next peak and return it
|
212 |
-
|
|
|
|
|
|
|
213 |
|
214 |
class States(Enum):
|
215 |
NO_SPINDLE = 0
|
|
|
3 |
import time
|
4 |
from threading import Thread, Lock
|
5 |
from pathlib import Path
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
|
9 |
from portiloop.src import ADS
|
10 |
|
|
|
148 |
|
149 |
# Class that delays stimulation to always stimulate peak or through
|
150 |
class UpStateDelayer:
|
151 |
+
def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
|
152 |
'''
|
153 |
args:
|
154 |
sample_freq: int -> Sampling frequency of signal in Hz
|
155 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
156 |
'''
|
157 |
# Get number of timesteps for a whole spindle
|
|
|
158 |
self.sample_freq = sample_freq
|
|
|
159 |
self.peak = peak
|
160 |
self.buffer = []
|
161 |
self.time_to_buffer = time_to_buffer
|
162 |
+
self.stimulate = stimulate
|
163 |
|
164 |
self.state = States.NO_SPINDLE
|
165 |
|
|
|
192 |
return True
|
193 |
return False
|
194 |
|
195 |
+
def step_timesteps(self, point):
|
196 |
+
'''
|
197 |
+
Step the delayer, ads a point to buffer if necessary.
|
198 |
+
Returns True if stimulation is actually done
|
199 |
+
'''
|
200 |
+
if self.state == States.NO_SPINDLE:
|
201 |
+
return False
|
202 |
+
elif self.state == States.BUFFERING:
|
203 |
+
self.buffer.append(point)
|
204 |
+
# If we are done buffering, move on to the waiting stage
|
205 |
+
if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
|
206 |
+
# Compute the necessary time to wait
|
207 |
+
self.time_to_wait = self.compute_time_to_wait()
|
208 |
+
self.state = States.DELAYING
|
209 |
+
self.buffer = []
|
210 |
+
self.delaying_counter = 0
|
211 |
+
return False
|
212 |
+
elif self.state == States.DELAYING:
|
213 |
+
# Check if we are done delaying
|
214 |
+
self.delaying_counter += 1
|
215 |
+
if self.delaying_counter >= self.time_to_wait * self.sample_freq:
|
216 |
+
# Actually stimulate the patient after the delay
|
217 |
+
if self.stimulate is not None:
|
218 |
+
self.stimulate()
|
219 |
+
# Reset state
|
220 |
+
self.time_to_wait = -1
|
221 |
+
self.state = States.NO_SPINDLE
|
222 |
+
return True
|
223 |
+
return False
|
224 |
+
|
225 |
def detected(self):
|
226 |
if self.state == States.NO_SPINDLE:
|
227 |
self.state = States.BUFFERING
|
|
|
228 |
|
229 |
def compute_time_to_wait(self):
|
230 |
"""
|
|
|
237 |
# Returns the index of the last peak in the buffer
|
238 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
239 |
|
240 |
+
# Make a figure to show the peaks
|
241 |
+
if False:
|
242 |
+
plt.figure()
|
243 |
+
plt.plot(self.buffer)
|
244 |
+
for peak in peaks:
|
245 |
+
plt.axvline(x=peak)
|
246 |
+
plt.plot(np.zeros_like(self.buffer), "--", color="gray")
|
247 |
+
plt.show()
|
248 |
+
|
249 |
+
if len(peaks) == 0:
|
250 |
+
print("No peaks found, increase buffer size")
|
251 |
+
return (self.sample_freq / 10) * (1.0 / self.sample_freq)
|
252 |
+
|
253 |
+
# Compute average distance between each peak
|
254 |
+
avg_dist = np.mean(np.diff(peaks))
|
255 |
+
|
256 |
# Compute the time until next peak and return it
|
257 |
+
if (avg_dist < len(self.buffer) - peaks[-1]):
|
258 |
+
print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
|
259 |
+
return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
|
260 |
+
return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)
|
261 |
|
262 |
class States(Enum):
|
263 |
NO_SPINDLE = 0
|