Merge pull request #5 from Portiloop/milo/filetypes_and_staging
- portiloop/src/demo/offline.py +29 -6
- portiloop/src/demo/phase_demo.py +63 -0
- portiloop/src/demo/test_offline.py +12 -7
- portiloop/src/demo/utils.py +49 -12
- portiloop/src/stimulation.py +54 -6
portiloop/src/demo/offline.py
CHANGED
@@ -1,13 +1,12 @@
-import matplotlib.pyplot as plt
 import numpy as np
 from portiloop.src.detection import SleepSpindleRealTimeDetector
-
+from portiloop.src.stimulation import UpStateDelayer
 from portiloop.src.processing import FilterPipeline
-from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
+from portiloop.src.demo.utils import compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
 import gradio as gr


-def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
+def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
     # Get the options from the checkbox group
     offline_filtering = 0 in detect_filter_opts
     lacourse = 1 in detect_filter_opts
@@ -30,6 +29,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
     # Read the xdf file to a numpy array
     print("Loading xdf file...")
     data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
+
     # Do the offline filtering of the data
     if offline_filtering:
         print("Filtering offline...")
@@ -39,13 +39,18 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
         data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
         columns.append("offline_filtered_signal")

+    # Do the sleep staging approximation
+    if wamsley or lacourse:
+        print("Sleep staging...")
+        mask = sleep_stage(data_whole[:, columns.index("offline_filtered_signal")], threshold=150, group_size=100)
+
     # Do Wamsley's method
     if wamsley:
         print("Running Wamsley detection...")
         wamsley_data = offline_detect("Wamsley", \
             data_whole[:, columns.index("offline_filtered_signal")],\
             data_whole[:, columns.index("time_stamps")],\
-            freq)
+            freq, mask)
         wamsley_data = np.expand_dims(wamsley_data, axis=1)
         data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
         columns.append("wamsley_spindles")
@@ -56,7 +61,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
         lacourse_data = offline_detect("Lacourse", \
             data_whole[:, columns.index("offline_filtered_signal")],\
             data_whole[:, columns.index("time_stamps")],\
-            freq)
+            freq, mask)
         lacourse_data = np.expand_dims(lacourse_data, axis=1)
         data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
         columns.append("lacourse_spindles")
@@ -72,12 +77,17 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
     if online_detection:
         detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
         stimulator = OfflineSleepSpindleRealTimeStimulator()
+        if stimulation_phase != "Fast":
+            stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
+            stimulator.add_delayer(stimulation_delayer)
+

     if online_filtering or online_detection:
         print("Running online filtering and detection...")

         points = []
         online_activations = []
+        delayed_stims = []

         # Go through the data
         for index, point in enumerate(data):
@@ -93,6 +103,13 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
             # Detect the spindles
             result = detector.detect([filtered_point])

+            if stimulation_phase != "Fast":
+                delayed_stim = stimulation_delayer.step_timesteps(filtered_point[0])
+                if delayed_stim:
+                    delayed_stims.append(1)
+                else:
+                    delayed_stims.append(0)
+
             # Stimulate if necessary
             stim = stimulator.stimulate(result)
             if stim:
@@ -112,6 +129,12 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
     data_whole = np.concatenate((data_whole, online_activations), axis=1)
     columns.append("online_stimulations")

+    if stimulation_phase != "Fast":
+        delayed_stims = np.array(delayed_stims)
+        delayed_stims = np.expand_dims(delayed_stims, axis=1)
+        data_whole = np.concatenate((data_whole, delayed_stims), axis=1)
+        columns.append("delayed_stimulations")
+
     print("Saving output...")
     # Output the data to a csv file
     np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
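For context, the reworked entry point can also be driven outside of Gradio. The snippet below is a usage sketch, not part of the PR: the file name is an assumption, and the option indices follow the checkbox order defined in phase_demo.py (0 offline filtering, 1 Lacourse, 2 Wamsley, 3 online filtering, 4 online detection). Any stimulation_phase other than "Fast" activates the new UpStateDelayer path.

from portiloop.src.demo.offline import run_offline

# Hypothetical recording; run_offline only needs an object with a .name attribute
# pointing to an XDF file, which is what the Gradio file component provides.
xdf = open("recording.xdf")
outputs = list(run_offline(
    xdf,
    [0, 3, 4],                      # offline filtering + online filtering + online detection
    threshold=0.82,
    channel_num=2,
    freq=250,
    stimulation_phase="Peak",       # "Fast" keeps the previous behaviour
    buffer_time=0.3))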
portiloop/src/demo/phase_demo.py
ADDED
@@ -0,0 +1,63 @@
+import gradio as gr
+
+from portiloop.src.demo.offline import run_offline
+
+
+def on_upload_file(file):
+    # Check if file extension is .xdf
+    if file.name.split(".")[-1] != "xdf":
+        raise gr.Error("Please upload a .xdf file.")
+    else:
+        return file.name
+
+
+def main():
+    with gr.Blocks(title="Portiloop") as demo:
+        gr.Markdown("# Portiloop Demo")
+        gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
+        gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
+
+        with gr.Row():
+            xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
+            xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
+
+        xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
+
+        # Make a checkbox group for the options
+        detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
+
+        # Options for phase stimulation
+        with gr.Row():
+            # Dropdown for phase
+            phase = gr.Dropdown(choices=["Peak", "Fast", "Valley"], value="Peak", label="Phase", interactive=True)
+            buffer_time = gr.Slider(0, 1, value=0.3, step=0.01, label="Buffer Time", interactive=True)
+
+        # Threshold value
+        threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
+        # Detection Channel
+        detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
+        # Frequency
+        freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
+
+        with gr.Row():
+            output_array = gr.File(label="Output CSV File")
+            output_table = gr.Markdown(label="Output Table")
+
+        run_inference = gr.Button(value="Run Inference")
+        run_inference.click(
+            fn=run_offline,
+            inputs=[
+                xdf_file_static,
+                detect_filter,
+                threshold,
+                detect_channel,
+                freq,
+                phase,
+                buffer_time],
+            outputs=[output_array, output_table])
+
+    demo.queue()
+    demo.launch(share=False)
+
+if __name__ == "__main__":
+    main()
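Since the new module guards its entry point with if __name__ == "__main__", it can be started as a script or programmatically. A minimal sketch, assuming the repository root is on the Python path:

# Builds the Gradio UI defined above; demo.launch(share=False) blocks until stopped.
from portiloop.src.demo.phase_demo import main

main()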
portiloop/src/demo/test_offline.py
CHANGED
@@ -2,7 +2,9 @@ import itertools
 import unittest
 from portiloop.src.demo.offline import run_offline
 from pathlib import Path
+import matplotlib.pyplot as plt

+from portiloop.src.demo.utils import sleep_stage, xdf2array

 class TestOffline(unittest.TestCase):

@@ -21,7 +23,7 @@ class TestOffline(unittest.TestCase):
         all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
         all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
         self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
-        self.xdf_file = Path(__file__).parents[3] / "
+        self.xdf_file = Path(__file__).parents[3] / "test_file.xdf"


     def test_all_options(self):
@@ -30,17 +32,20 @@
         self.assertTrue(config['online_filtering'])

     def test_single_option(self):
+
+        # Test options correspond to an index in the possible checkbox group options
+        test_options = [0, 1, 2]
+
         res = list(run_offline(
             self.xdf_file,
-
-            online_filtering=True,
-            online_detection=True,
-            wamsley=True,
-            lacourse=True,
+            test_options,
             threshold=0.5,
             channel_num=2,
-            freq=250
+            freq=250,
+            stimulation_phase="Peak",
+            buffer_time=0.3))
         print(res)
+        pass

     def tearDown(self):
         pass
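The updated test_single_option now passes the checkbox indices positionally and exercises the new stimulation_phase and buffer_time arguments. A sketch of running it, assuming test_file.xdf sits at the repository root as setUp expects:

import unittest

# Runs TestOffline without exiting the interpreter.
unittest.main(module="portiloop.src.demo.test_offline", exit=False, verbosity=2)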
portiloop/src/demo/utils.py
CHANGED
@@ -13,6 +13,32 @@ STREAM_NAMES = {
 }


+def sleep_stage(data, threshold=150, group_size=2):
+    """Sleep stage approximation using a threshold and a group size.
+    Returns a numpy array containing all indices in the input data which CAN be used for offline detection.
+    These indices can then be used to reconstruct the signal from the original data.
+    """
+    # Find all indexes where the signal is above or below the threshold
+    above = np.where(data > threshold)
+    below = np.where(data < -threshold)
+    indices = np.concatenate((above, below), axis=1)[0]
+
+    indices = np.sort(indices)
+    # Get all the indices where the difference between two consecutive indices is larger than 100
+    groups = np.where(np.diff(indices) <= group_size)[0] + 1
+    # Get the important indices
+    important_indices = indices[groups]
+    # Get all the indices between the important indices
+    group_filler = [np.arange(indices[groups[n] - 1] + 1, index) for n, index in enumerate(important_indices)]
+    # Create flat array from fillers
+    group_filler = np.concatenate(group_filler)
+    # Append all group fillers to the indices
+    masked_indices = np.sort(np.concatenate((indices, group_filler)))
+    unmasked_indices = np.setdiff1d(np.arange(len(data)), masked_indices)
+
+    return unmasked_indices
+
+
 class OfflineSleepSpindleRealTimeStimulator(Stimulator):
     def __init__(self):
         self.last_detected_ts = time.time()
@@ -87,15 +113,19 @@ def xdf2array(xdf_path, channel):
     return np.array(csv_list), columns


-def offline_detect(method, data, timesteps, freq):
+def offline_detect(method, data, timesteps, freq, mask):
+    # Extract only the interesting elements from the mask
+    data_masked = data[mask]
+
     # Get the spindle data from the offline methods
     time = np.arange(0, len(data)) / freq
+    time_masked = time[mask]
     if method == "Lacourse":
         detector = DetectSpindle(method='Lacourse2018')
-        spindles, _, _ = detect_Lacourse2018(
+        spindles, _, _ = detect_Lacourse2018(data_masked, freq, time_masked, detector)
     elif method == "Wamsley":
         detector = DetectSpindle(method='Wamsley2012')
-        spindles, _, _ = detect_Wamsley2012(
+        spindles, _, _ = detect_Wamsley2012(data_masked, freq, time_masked, detector)
     else:
         raise ValueError("Invalid method")

@@ -134,18 +164,25 @@ def offline_filter(signal, freq):
 def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
     # Count the number of spindles detected by each method
     online_stimulation_count = np.sum(online_stimulation)
-
-
-
-
-
-
-
+    if lacourse_spindles is not None:
+        lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
+        # Count how many spindles were detected by both online and lacourse
+        both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
+
+    if wamsley_spindles is not None:
+        wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
+        # Count how many spindles were detected by both online and wamsley
+        both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
+
+

     # Create markdown table with the results
     table = "| Method | Detected spindles | Overlap with Portiloop |\n"
     table += "| --- | --- | --- |\n"
     table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
-
-
+    if lacourse_spindles is not None:
+        table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
+    if wamsley_spindles is not None:
+        table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
     return table
+
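To illustrate what the new sleep_stage mask does before it reaches offline_detect, here is a small sketch on synthetic data. The signal, its amplitude scale, and the artifact location are assumptions made for illustration; the threshold and group_size match the values run_offline passes.

import numpy as np
from portiloop.src.demo.utils import sleep_stage

freq = 250
signal = np.random.randn(60 * freq) * 20.0      # one minute of quiet background
signal[4000:4200] += 400.0                      # a burst exceeding the 150 threshold

mask = sleep_stage(signal, threshold=150, group_size=100)
clean = signal[mask]                            # samples kept for Wamsley/Lacourse detection
print(len(signal) - len(clean))                 # roughly the size of the rejected burst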
portiloop/src/stimulation.py
CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
 import time
 from threading import Thread, Lock
 from pathlib import Path
+import numpy as np
+import matplotlib.pyplot as plt

 from portiloop.src import ADS

@@ -146,20 +148,18 @@ class SleepSpindleRealTimeStimulator(Stimulator):

 # Class that delays stimulation to always stimulate peak or trough
 class UpStateDelayer:
-    def __init__(self, sample_freq,
+    def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
         '''
         args:
             sample_freq: int -> Sampling frequency of signal in Hz
             time_to_wait: float -> Time to wait to build buffer in seconds
         '''
         # Get number of timesteps for a whole spindle
-        self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
         self.sample_freq = sample_freq
-        self.buffer_size = 1.5 * self.spindle_timesteps
         self.peak = peak
         self.buffer = []
         self.time_to_buffer = time_to_buffer
-        self.stimulate =
+        self.stimulate = stimulate

         self.state = States.NO_SPINDLE

@@ -192,10 +192,39 @@ class UpStateDelayer:
             return True
         return False

+    def step_timesteps(self, point):
+        '''
+        Step the delayer, adds a point to the buffer if necessary.
+        Returns True if stimulation is actually done
+        '''
+        if self.state == States.NO_SPINDLE:
+            return False
+        elif self.state == States.BUFFERING:
+            self.buffer.append(point)
+            # If we are done buffering, move on to the waiting stage
+            if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
+                # Compute the necessary time to wait
+                self.time_to_wait = self.compute_time_to_wait()
+                self.state = States.DELAYING
+                self.buffer = []
+                self.delaying_counter = 0
+            return False
+        elif self.state == States.DELAYING:
+            # Check if we are done delaying
+            self.delaying_counter += 1
+            if self.delaying_counter >= self.time_to_wait * self.sample_freq:
+                # Actually stimulate the patient after the delay
+                if self.stimulate is not None:
+                    self.stimulate()
+                # Reset state
+                self.time_to_wait = -1
+                self.state = States.NO_SPINDLE
+                return True
+            return False
+
     def detected(self):
         if self.state == States.NO_SPINDLE:
             self.state = States.BUFFERING
-            self.time_started = time.time()

     def compute_time_to_wait(self):
         """
@@ -208,8 +237,27 @@ class UpStateDelayer:
         # Returns the index of the last peak in the buffer
         peaks, _ = find_peaks(self.buffer, prominence=1)

+        # Make a figure to show the peaks
+        if False:
+            plt.figure()
+            plt.plot(self.buffer)
+            for peak in peaks:
+                plt.axvline(x=peak)
+            plt.plot(np.zeros_like(self.buffer), "--", color="gray")
+            plt.show()
+
+        if len(peaks) == 0:
+            print("No peaks found, increase buffer size")
+            return (self.sample_freq / 10) * (1.0 / self.sample_freq)
+
+        # Compute average distance between each peak
+        avg_dist = np.mean(np.diff(peaks))
+
         # Compute the time until next peak and return it
-
+        if (avg_dist < len(self.buffer) - peaks[-1]):
+            print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
+            return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
+        return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)

 class States(Enum):
     NO_SPINDLE = 0