Milo Sobral commited on
Commit
986653d
·
1 Parent(s): d4bce82

Added the Phase demo

Browse files
portiloop/src/demo/offline.py CHANGED
@@ -1,13 +1,12 @@
1
- import matplotlib.pyplot as plt
2
  import numpy as np
3
  from portiloop.src.detection import SleepSpindleRealTimeDetector
4
- plt.switch_backend('agg')
5
  from portiloop.src.processing import FilterPipeline
6
  from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
7
  import gradio as gr
8
 
9
 
10
- def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
11
  # Get the options from the checkbox group
12
  offline_filtering = 0 in detect_filter_opts
13
  lacourse = 1 in detect_filter_opts
@@ -72,12 +71,17 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
72
  if online_detection:
73
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
74
  stimulator = OfflineSleepSpindleRealTimeStimulator()
 
 
 
 
75
 
76
  if online_filtering or online_detection:
77
  print("Running online filtering and detection...")
78
 
79
  points = []
80
  online_activations = []
 
81
 
82
  # Go through the data
83
  for index, point in enumerate(data):
@@ -93,6 +97,13 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
93
  # Detect the spindles
94
  result = detector.detect([filtered_point])
95
 
 
 
 
 
 
 
 
96
  # Stimulate if necessary
97
  stim = stimulator.stimulate(result)
98
  if stim:
@@ -112,6 +123,12 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
112
  data_whole = np.concatenate((data_whole, online_activations), axis=1)
113
  columns.append("online_stimulations")
114
 
 
 
 
 
 
 
115
  print("Saving output...")
116
  # Output the data to a csv file
117
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
 
 
1
  import numpy as np
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
+ from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
  from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
9
+ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
10
  # Get the options from the checkbox group
11
  offline_filtering = 0 in detect_filter_opts
12
  lacourse = 1 in detect_filter_opts
 
71
  if online_detection:
72
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
73
  stimulator = OfflineSleepSpindleRealTimeStimulator()
74
+ if stimulation_phase != "Fast":
75
+ stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
76
+ stimulator.add_delayer(stimulation_delayer)
77
+
78
 
79
  if online_filtering or online_detection:
80
  print("Running online filtering and detection...")
81
 
82
  points = []
83
  online_activations = []
84
+ delayed_stims = []
85
 
86
  # Go through the data
87
  for index, point in enumerate(data):
 
97
  # Detect the spindles
98
  result = detector.detect([filtered_point])
99
 
100
+ if stimulation_phase != "Fast":
101
+ delayed_stim = stimulation_delayer.step_timesteps(filtered_point[0])
102
+ if delayed_stim:
103
+ delayed_stims.append(1)
104
+ else:
105
+ delayed_stims.append(0)
106
+
107
  # Stimulate if necessary
108
  stim = stimulator.stimulate(result)
109
  if stim:
 
123
  data_whole = np.concatenate((data_whole, online_activations), axis=1)
124
  columns.append("online_stimulations")
125
 
126
+ if stimulation_phase != "Fast":
127
+ delayed_stims = np.array(delayed_stims)
128
+ delayed_stims = np.expand_dims(delayed_stims, axis=1)
129
+ data_whole = np.concatenate((data_whole, delayed_stims), axis=1)
130
+ columns.append("delayed_stimulations")
131
+
132
  print("Saving output...")
133
  # Output the data to a csv file
134
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
portiloop/src/demo/phase_demo.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from portiloop.src.demo.offline import run_offline
4
+
5
+
6
+ def on_upload_file(file):
7
+ # Check if file extension is .xdf
8
+ if file.name.split(".")[-1] != "xdf":
9
+ raise gr.Error("Please upload a .xdf file.")
10
+ else:
11
+ return file.name
12
+
13
+
14
+ def main():
15
+ with gr.Blocks(title="Portiloop") as demo:
16
+ gr.Markdown("# Portiloop Demo")
17
+ gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
18
+ gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
19
+
20
+ with gr.Row():
21
+ xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
22
+ xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
23
+
24
+ xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
25
+
26
+ # Make a checkbox group for the options
27
+ detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
28
+
29
+ # Options for phase stimulation
30
+ with gr.Row():
31
+ # Dropwdown for phase
32
+ phase = gr.Dropdown(choices=["Peak", "Fast", "Valley"], value="Peak", label="Phase", interactive=True)
33
+ buffer_time = gr.Slider(0, 1, value=0.3, step=0.01, label="Buffer Time", interactive=True)
34
+
35
+ # Threshold value
36
+ threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
37
+ # Detection Channel
38
+ detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
39
+ # Frequency
40
+ freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
41
+
42
+ with gr.Row():
43
+ output_array = gr.File(label="Output CSV File")
44
+ output_table = gr.Markdown(label="Output Table")
45
+
46
+ run_inference = gr.Button(value="Run Inference")
47
+ run_inference.click(
48
+ fn=run_offline,
49
+ inputs=[
50
+ xdf_file_static,
51
+ detect_filter,
52
+ threshold,
53
+ detect_channel,
54
+ freq,
55
+ phase,
56
+ buffer_time],
57
+ outputs=[output_array, output_table])
58
+
59
+ demo.queue()
60
+ demo.launch(share=False)
61
+
62
+ if __name__ == "__main__":
63
+ main()
portiloop/src/demo/test_offline.py CHANGED
@@ -2,7 +2,7 @@ import itertools
2
  import unittest
3
  from portiloop.src.demo.offline import run_offline
4
  from pathlib import Path
5
-
6
 
7
  class TestOffline(unittest.TestCase):
8
 
@@ -30,16 +30,18 @@ class TestOffline(unittest.TestCase):
30
  self.assertTrue(config['online_filtering'])
31
 
32
  def test_single_option(self):
 
 
 
 
33
  res = list(run_offline(
34
  self.xdf_file,
35
- offline_filtering=True,
36
- online_filtering=True,
37
- online_detection=True,
38
- wamsley=True,
39
- lacourse=True,
40
  threshold=0.5,
41
  channel_num=2,
42
- freq=250))
 
 
43
  print(res)
44
 
45
  def tearDown(self):
 
2
  import unittest
3
  from portiloop.src.demo.offline import run_offline
4
  from pathlib import Path
5
+ import matplotlib.pyplot as plt
6
 
7
  class TestOffline(unittest.TestCase):
8
 
 
30
  self.assertTrue(config['online_filtering'])
31
 
32
  def test_single_option(self):
33
+
34
+ # Test options correspond to an index in the possible checkbox group options
35
+ test_options = [3, 4]
36
+
37
  res = list(run_offline(
38
  self.xdf_file,
39
+ test_options,
 
 
 
 
40
  threshold=0.5,
41
  channel_num=2,
42
+ freq=250,
43
+ stimulation_phase="Peak",
44
+ buffer_time=0.3))
45
  print(res)
46
 
47
  def tearDown(self):
portiloop/src/demo/utils.py CHANGED
@@ -134,18 +134,24 @@ def offline_filter(signal, freq):
134
  def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
135
  # Count the number of spindles detected by each method
136
  online_stimulation_count = np.sum(online_stimulation)
137
- lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
138
- wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
139
-
140
- # Count how many spindles were detected by both online and lacourse
141
- both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
142
- # Count how many spindles were detected by both online and wamsley
143
- both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
 
 
 
 
144
 
145
  # Create markdown table with the results
146
  table = "| Method | Detected spindles | Overlap with Portiloop |\n"
147
  table += "| --- | --- | --- |\n"
148
  table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
149
- table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
150
- table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
 
 
151
  return table
 
134
  def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
135
  # Count the number of spindles detected by each method
136
  online_stimulation_count = np.sum(online_stimulation)
137
+ if lacourse_spindles is not None:
138
+ lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
139
+ # Count how many spindles were detected by both online and lacourse
140
+ both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
141
+
142
+ if wamsley_spindles is not None:
143
+ wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
144
+ # Count how many spindles were detected by both online and wamsley
145
+ both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
146
+
147
+
148
 
149
  # Create markdown table with the results
150
  table = "| Method | Detected spindles | Overlap with Portiloop |\n"
151
  table += "| --- | --- | --- |\n"
152
  table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
153
+ if lacourse_spindles is not None:
154
+ table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
155
+ if wamsley_spindles is not None:
156
+ table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
157
  return table
portiloop/src/stimulation.py CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
3
  import time
4
  from threading import Thread, Lock
5
  from pathlib import Path
 
 
6
 
7
  from portiloop.src import ADS
8
 
@@ -146,20 +148,18 @@ class SleepSpindleRealTimeStimulator(Stimulator):
146
 
147
  # Class that delays stimulation to always stimulate peak or through
148
  class UpStateDelayer:
149
- def __init__(self, sample_freq, spindle_freq, peak, time_to_buffer):
150
  '''
151
  args:
152
  sample_freq: int -> Sampling frequency of signal in Hz
153
  time_to_wait: float -> Time to wait to build buffer in seconds
154
  '''
155
  # Get number of timesteps for a whole spindle
156
- self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
157
  self.sample_freq = sample_freq
158
- self.buffer_size = 1.5 * self.spindle_timesteps
159
  self.peak = peak
160
  self.buffer = []
161
  self.time_to_buffer = time_to_buffer
162
- self.stimulate = None
163
 
164
  self.state = States.NO_SPINDLE
165
 
@@ -192,10 +192,39 @@ class UpStateDelayer:
192
  return True
193
  return False
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def detected(self):
196
  if self.state == States.NO_SPINDLE:
197
  self.state = States.BUFFERING
198
- self.time_started = time.time()
199
 
200
  def compute_time_to_wait(self):
201
  """
@@ -208,8 +237,27 @@ class UpStateDelayer:
208
  # Returns the index of the last peak in the buffer
209
  peaks, _ = find_peaks(self.buffer, prominence=1)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  # Compute the time until next peak and return it
212
- return (len(self.buffer) - peaks[-1]) * (1 / self.sample_freq)
 
 
 
213
 
214
  class States(Enum):
215
  NO_SPINDLE = 0
 
3
  import time
4
  from threading import Thread, Lock
5
  from pathlib import Path
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
  from portiloop.src import ADS
10
 
 
148
 
149
  # Class that delays stimulation to always stimulate peak or through
150
  class UpStateDelayer:
151
+ def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
152
  '''
153
  args:
154
  sample_freq: int -> Sampling frequency of signal in Hz
155
  time_to_wait: float -> Time to wait to build buffer in seconds
156
  '''
157
  # Get number of timesteps for a whole spindle
 
158
  self.sample_freq = sample_freq
 
159
  self.peak = peak
160
  self.buffer = []
161
  self.time_to_buffer = time_to_buffer
162
+ self.stimulate = stimulate
163
 
164
  self.state = States.NO_SPINDLE
165
 
 
192
  return True
193
  return False
194
 
195
+ def step_timesteps(self, point):
196
+ '''
197
+ Step the delayer, ads a point to buffer if necessary.
198
+ Returns True if stimulation is actually done
199
+ '''
200
+ if self.state == States.NO_SPINDLE:
201
+ return False
202
+ elif self.state == States.BUFFERING:
203
+ self.buffer.append(point)
204
+ # If we are done buffering, move on to the waiting stage
205
+ if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
206
+ # Compute the necessary time to wait
207
+ self.time_to_wait = self.compute_time_to_wait()
208
+ self.state = States.DELAYING
209
+ self.buffer = []
210
+ self.delaying_counter = 0
211
+ return False
212
+ elif self.state == States.DELAYING:
213
+ # Check if we are done delaying
214
+ self.delaying_counter += 1
215
+ if self.delaying_counter >= self.time_to_wait * self.sample_freq:
216
+ # Actually stimulate the patient after the delay
217
+ if self.stimulate is not None:
218
+ self.stimulate()
219
+ # Reset state
220
+ self.time_to_wait = -1
221
+ self.state = States.NO_SPINDLE
222
+ return True
223
+ return False
224
+
225
  def detected(self):
226
  if self.state == States.NO_SPINDLE:
227
  self.state = States.BUFFERING
 
228
 
229
  def compute_time_to_wait(self):
230
  """
 
237
  # Returns the index of the last peak in the buffer
238
  peaks, _ = find_peaks(self.buffer, prominence=1)
239
 
240
+ # Make a figure to show the peaks
241
+ if False:
242
+ plt.figure()
243
+ plt.plot(self.buffer)
244
+ for peak in peaks:
245
+ plt.axvline(x=peak)
246
+ plt.plot(np.zeros_like(self.buffer), "--", color="gray")
247
+ plt.show()
248
+
249
+ if len(peaks) == 0:
250
+ print("No peaks found, increase buffer size")
251
+ return (self.sample_freq / 10) * (1.0 / self.sample_freq)
252
+
253
+ # Compute average distance between each peak
254
+ avg_dist = np.mean(np.diff(peaks))
255
+
256
  # Compute the time until next peak and return it
257
+ if (avg_dist < len(self.buffer) - peaks[-1]):
258
+ print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
259
+ return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
260
+ return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)
261
 
262
  class States(Enum):
263
  NO_SPINDLE = 0