Milo Sobral commited on
Commit
902bac9
·
unverified ·
2 Parent(s): d4bce82 bf3350c

Merge pull request #5 from Portiloop/milo/filetypes_and_staging

Browse files
portiloop/src/demo/offline.py CHANGED
@@ -1,13 +1,12 @@
1
- import matplotlib.pyplot as plt
2
  import numpy as np
3
  from portiloop.src.detection import SleepSpindleRealTimeDetector
4
- plt.switch_backend('agg')
5
  from portiloop.src.processing import FilterPipeline
6
- from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
7
  import gradio as gr
8
 
9
 
10
- def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
11
  # Get the options from the checkbox group
12
  offline_filtering = 0 in detect_filter_opts
13
  lacourse = 1 in detect_filter_opts
@@ -30,6 +29,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
30
  # Read the xdf file to a numpy array
31
  print("Loading xdf file...")
32
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
 
33
  # Do the offline filtering of the data
34
  if offline_filtering:
35
  print("Filtering offline...")
@@ -39,13 +39,18 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
39
  data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
40
  columns.append("offline_filtered_signal")
41
 
 
 
 
 
 
42
  # Do Wamsley's method
43
  if wamsley:
44
  print("Running Wamsley detection...")
45
  wamsley_data = offline_detect("Wamsley", \
46
  data_whole[:, columns.index("offline_filtered_signal")],\
47
  data_whole[:, columns.index("time_stamps")],\
48
- freq)
49
  wamsley_data = np.expand_dims(wamsley_data, axis=1)
50
  data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
51
  columns.append("wamsley_spindles")
@@ -56,7 +61,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
56
  lacourse_data = offline_detect("Lacourse", \
57
  data_whole[:, columns.index("offline_filtered_signal")],\
58
  data_whole[:, columns.index("time_stamps")],\
59
- freq)
60
  lacourse_data = np.expand_dims(lacourse_data, axis=1)
61
  data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
62
  columns.append("lacourse_spindles")
@@ -72,12 +77,17 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
72
  if online_detection:
73
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
74
  stimulator = OfflineSleepSpindleRealTimeStimulator()
 
 
 
 
75
 
76
  if online_filtering or online_detection:
77
  print("Running online filtering and detection...")
78
 
79
  points = []
80
  online_activations = []
 
81
 
82
  # Go through the data
83
  for index, point in enumerate(data):
@@ -93,6 +103,13 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
93
  # Detect the spindles
94
  result = detector.detect([filtered_point])
95
 
 
 
 
 
 
 
 
96
  # Stimulate if necessary
97
  stim = stimulator.stimulate(result)
98
  if stim:
@@ -112,6 +129,12 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
112
  data_whole = np.concatenate((data_whole, online_activations), axis=1)
113
  columns.append("online_stimulations")
114
 
 
 
 
 
 
 
115
  print("Saving output...")
116
  # Output the data to a csv file
117
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
 
 
1
  import numpy as np
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
+ from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
+ from portiloop.src.demo.utils import compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
9
+ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
10
  # Get the options from the checkbox group
11
  offline_filtering = 0 in detect_filter_opts
12
  lacourse = 1 in detect_filter_opts
 
29
  # Read the xdf file to a numpy array
30
  print("Loading xdf file...")
31
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
32
+
33
  # Do the offline filtering of the data
34
  if offline_filtering:
35
  print("Filtering offline...")
 
39
  data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
40
  columns.append("offline_filtered_signal")
41
 
42
+ # Do the sleep staging approximation
43
+ if wamsley or lacourse:
44
+ print("Sleep staging...")
45
+ mask = sleep_stage(data_whole[:, columns.index("offline_filtered_signal")], threshold=150, group_size=100)
46
+
47
  # Do Wamsley's method
48
  if wamsley:
49
  print("Running Wamsley detection...")
50
  wamsley_data = offline_detect("Wamsley", \
51
  data_whole[:, columns.index("offline_filtered_signal")],\
52
  data_whole[:, columns.index("time_stamps")],\
53
+ freq, mask)
54
  wamsley_data = np.expand_dims(wamsley_data, axis=1)
55
  data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
56
  columns.append("wamsley_spindles")
 
61
  lacourse_data = offline_detect("Lacourse", \
62
  data_whole[:, columns.index("offline_filtered_signal")],\
63
  data_whole[:, columns.index("time_stamps")],\
64
+ freq, mask)
65
  lacourse_data = np.expand_dims(lacourse_data, axis=1)
66
  data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
67
  columns.append("lacourse_spindles")
 
77
  if online_detection:
78
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
79
  stimulator = OfflineSleepSpindleRealTimeStimulator()
80
+ if stimulation_phase != "Fast":
81
+ stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
82
+ stimulator.add_delayer(stimulation_delayer)
83
+
84
 
85
  if online_filtering or online_detection:
86
  print("Running online filtering and detection...")
87
 
88
  points = []
89
  online_activations = []
90
+ delayed_stims = []
91
 
92
  # Go through the data
93
  for index, point in enumerate(data):
 
103
  # Detect the spindles
104
  result = detector.detect([filtered_point])
105
 
106
+ if stimulation_phase != "Fast":
107
+ delayed_stim = stimulation_delayer.step_timesteps(filtered_point[0])
108
+ if delayed_stim:
109
+ delayed_stims.append(1)
110
+ else:
111
+ delayed_stims.append(0)
112
+
113
  # Stimulate if necessary
114
  stim = stimulator.stimulate(result)
115
  if stim:
 
129
  data_whole = np.concatenate((data_whole, online_activations), axis=1)
130
  columns.append("online_stimulations")
131
 
132
+ if stimulation_phase != "Fast":
133
+ delayed_stims = np.array(delayed_stims)
134
+ delayed_stims = np.expand_dims(delayed_stims, axis=1)
135
+ data_whole = np.concatenate((data_whole, delayed_stims), axis=1)
136
+ columns.append("delayed_stimulations")
137
+
138
  print("Saving output...")
139
  # Output the data to a csv file
140
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
portiloop/src/demo/phase_demo.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from portiloop.src.demo.offline import run_offline
4
+
5
+
6
+ def on_upload_file(file):
7
+ # Check if file extension is .xdf
8
+ if file.name.split(".")[-1] != "xdf":
9
+ raise gr.Error("Please upload a .xdf file.")
10
+ else:
11
+ return file.name
12
+
13
+
14
+ def main():
15
+ with gr.Blocks(title="Portiloop") as demo:
16
+ gr.Markdown("# Portiloop Demo")
17
+ gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
18
+ gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
19
+
20
+ with gr.Row():
21
+ xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
22
+ xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
23
+
24
+ xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
25
+
26
+ # Make a checkbox group for the options
27
+ detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
28
+
29
+ # Options for phase stimulation
30
+ with gr.Row():
31
+ # Dropwdown for phase
32
+ phase = gr.Dropdown(choices=["Peak", "Fast", "Valley"], value="Peak", label="Phase", interactive=True)
33
+ buffer_time = gr.Slider(0, 1, value=0.3, step=0.01, label="Buffer Time", interactive=True)
34
+
35
+ # Threshold value
36
+ threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
37
+ # Detection Channel
38
+ detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
39
+ # Frequency
40
+ freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
41
+
42
+ with gr.Row():
43
+ output_array = gr.File(label="Output CSV File")
44
+ output_table = gr.Markdown(label="Output Table")
45
+
46
+ run_inference = gr.Button(value="Run Inference")
47
+ run_inference.click(
48
+ fn=run_offline,
49
+ inputs=[
50
+ xdf_file_static,
51
+ detect_filter,
52
+ threshold,
53
+ detect_channel,
54
+ freq,
55
+ phase,
56
+ buffer_time],
57
+ outputs=[output_array, output_table])
58
+
59
+ demo.queue()
60
+ demo.launch(share=False)
61
+
62
+ if __name__ == "__main__":
63
+ main()
portiloop/src/demo/test_offline.py CHANGED
@@ -2,7 +2,9 @@ import itertools
2
  import unittest
3
  from portiloop.src.demo.offline import run_offline
4
  from pathlib import Path
 
5
 
 
6
 
7
  class TestOffline(unittest.TestCase):
8
 
@@ -21,7 +23,7 @@ class TestOffline(unittest.TestCase):
21
  all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
22
  all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
23
  self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
24
- self.xdf_file = Path(__file__).parents[3] / "test_xdf" / "test_file.xdf"
25
 
26
 
27
  def test_all_options(self):
@@ -30,17 +32,20 @@ class TestOffline(unittest.TestCase):
30
  self.assertTrue(config['online_filtering'])
31
 
32
  def test_single_option(self):
 
 
 
 
33
  res = list(run_offline(
34
  self.xdf_file,
35
- offline_filtering=True,
36
- online_filtering=True,
37
- online_detection=True,
38
- wamsley=True,
39
- lacourse=True,
40
  threshold=0.5,
41
  channel_num=2,
42
- freq=250))
 
 
43
  print(res)
 
44
 
45
  def tearDown(self):
46
  pass
 
2
  import unittest
3
  from portiloop.src.demo.offline import run_offline
4
  from pathlib import Path
5
+ import matplotlib.pyplot as plt
6
 
7
+ from portiloop.src.demo.utils import sleep_stage, xdf2array
8
 
9
  class TestOffline(unittest.TestCase):
10
 
 
23
  all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
24
  all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
25
  self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
26
+ self.xdf_file = Path(__file__).parents[3] / "test_file.xdf"
27
 
28
 
29
  def test_all_options(self):
 
32
  self.assertTrue(config['online_filtering'])
33
 
34
  def test_single_option(self):
35
+
36
+ # Test options correspond to an index in the possible checkbox group options
37
+ test_options = [0, 1, 2]
38
+
39
  res = list(run_offline(
40
  self.xdf_file,
41
+ test_options,
 
 
 
 
42
  threshold=0.5,
43
  channel_num=2,
44
+ freq=250,
45
+ stimulation_phase="Peak",
46
+ buffer_time=0.3))
47
  print(res)
48
+ pass
49
 
50
  def tearDown(self):
51
  pass
portiloop/src/demo/utils.py CHANGED
@@ -13,6 +13,32 @@ STREAM_NAMES = {
13
  }
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
17
  def __init__(self):
18
  self.last_detected_ts = time.time()
@@ -87,15 +113,19 @@ def xdf2array(xdf_path, channel):
87
  return np.array(csv_list), columns
88
 
89
 
90
- def offline_detect(method, data, timesteps, freq):
 
 
 
91
  # Get the spindle data from the offline methods
92
  time = np.arange(0, len(data)) / freq
 
93
  if method == "Lacourse":
94
  detector = DetectSpindle(method='Lacourse2018')
95
- spindles, _, _ = detect_Lacourse2018(data, freq, time, detector)
96
  elif method == "Wamsley":
97
  detector = DetectSpindle(method='Wamsley2012')
98
- spindles, _, _ = detect_Wamsley2012(data, freq, time, detector)
99
  else:
100
  raise ValueError("Invalid method")
101
 
@@ -134,18 +164,25 @@ def offline_filter(signal, freq):
134
  def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
135
  # Count the number of spindles detected by each method
136
  online_stimulation_count = np.sum(online_stimulation)
137
- lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
138
- wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
139
-
140
- # Count how many spindles were detected by both online and lacourse
141
- both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
142
- # Count how many spindles were detected by both online and wamsley
143
- both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
 
 
 
 
144
 
145
  # Create markdown table with the results
146
  table = "| Method | Detected spindles | Overlap with Portiloop |\n"
147
  table += "| --- | --- | --- |\n"
148
  table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
149
- table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
150
- table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
 
 
151
  return table
 
 
13
  }
14
 
15
 
16
+ def sleep_stage(data, threshold=150, group_size=2):
17
+ """Sleep stage approximation using a threshold and a group size.
18
+ Returns a numpy array containing all indices in the input data which CAN be used for offline detection.
19
+ These indices can then be used to reconstruct the signal from the original data.
20
+ """
21
+ # Find all indexes where the signal is above or below the threshold
22
+ above = np.where(data > threshold)
23
+ below = np.where(data < -threshold)
24
+ indices = np.concatenate((above, below), axis=1)[0]
25
+
26
+ indices = np.sort(indices)
27
+ # Get all the indices where the difference between two consecutive indices is larger than 100
28
+ groups = np.where(np.diff(indices) <= group_size)[0] + 1
29
+ # Get the important indices
30
+ important_indices = indices[groups]
31
+ # Get all the indices between the important indices
32
+ group_filler = [np.arange(indices[groups[n] - 1] + 1, index) for n, index in enumerate(important_indices)]
33
+ # Create flat array from fillers
34
+ group_filler = np.concatenate(group_filler)
35
+ # Append all group fillers to the indices
36
+ masked_indices = np.sort(np.concatenate((indices, group_filler)))
37
+ unmasked_indices = np.setdiff1d(np.arange(len(data)), masked_indices)
38
+
39
+ return unmasked_indices
40
+
41
+
42
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
43
  def __init__(self):
44
  self.last_detected_ts = time.time()
 
113
  return np.array(csv_list), columns
114
 
115
 
116
+ def offline_detect(method, data, timesteps, freq, mask):
117
+ # Extract only the interesting elements from the mask
118
+ data_masked = data[mask]
119
+
120
  # Get the spindle data from the offline methods
121
  time = np.arange(0, len(data)) / freq
122
+ time_masked = time[mask]
123
  if method == "Lacourse":
124
  detector = DetectSpindle(method='Lacourse2018')
125
+ spindles, _, _ = detect_Lacourse2018(data_masked, freq, time_masked, detector)
126
  elif method == "Wamsley":
127
  detector = DetectSpindle(method='Wamsley2012')
128
+ spindles, _, _ = detect_Wamsley2012(data_masked, freq, time_masked, detector)
129
  else:
130
  raise ValueError("Invalid method")
131
 
 
164
  def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
165
  # Count the number of spindles detected by each method
166
  online_stimulation_count = np.sum(online_stimulation)
167
+ if lacourse_spindles is not None:
168
+ lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
169
+ # Count how many spindles were detected by both online and lacourse
170
+ both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
171
+
172
+ if wamsley_spindles is not None:
173
+ wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
174
+ # Count how many spindles were detected by both online and wamsley
175
+ both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
176
+
177
+
178
 
179
  # Create markdown table with the results
180
  table = "| Method | Detected spindles | Overlap with Portiloop |\n"
181
  table += "| --- | --- | --- |\n"
182
  table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
183
+ if lacourse_spindles is not None:
184
+ table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
185
+ if wamsley_spindles is not None:
186
+ table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
187
  return table
188
+
portiloop/src/stimulation.py CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
3
  import time
4
  from threading import Thread, Lock
5
  from pathlib import Path
 
 
6
 
7
  from portiloop.src import ADS
8
 
@@ -146,20 +148,18 @@ class SleepSpindleRealTimeStimulator(Stimulator):
146
 
147
  # Class that delays stimulation to always stimulate peak or through
148
  class UpStateDelayer:
149
- def __init__(self, sample_freq, spindle_freq, peak, time_to_buffer):
150
  '''
151
  args:
152
  sample_freq: int -> Sampling frequency of signal in Hz
153
  time_to_wait: float -> Time to wait to build buffer in seconds
154
  '''
155
  # Get number of timesteps for a whole spindle
156
- self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
157
  self.sample_freq = sample_freq
158
- self.buffer_size = 1.5 * self.spindle_timesteps
159
  self.peak = peak
160
  self.buffer = []
161
  self.time_to_buffer = time_to_buffer
162
- self.stimulate = None
163
 
164
  self.state = States.NO_SPINDLE
165
 
@@ -192,10 +192,39 @@ class UpStateDelayer:
192
  return True
193
  return False
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def detected(self):
196
  if self.state == States.NO_SPINDLE:
197
  self.state = States.BUFFERING
198
- self.time_started = time.time()
199
 
200
  def compute_time_to_wait(self):
201
  """
@@ -208,8 +237,27 @@ class UpStateDelayer:
208
  # Returns the index of the last peak in the buffer
209
  peaks, _ = find_peaks(self.buffer, prominence=1)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  # Compute the time until next peak and return it
212
- return (len(self.buffer) - peaks[-1]) * (1 / self.sample_freq)
 
 
 
213
 
214
  class States(Enum):
215
  NO_SPINDLE = 0
 
3
  import time
4
  from threading import Thread, Lock
5
  from pathlib import Path
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
  from portiloop.src import ADS
10
 
 
148
 
149
  # Class that delays stimulation to always stimulate peak or through
150
  class UpStateDelayer:
151
+ def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
152
  '''
153
  args:
154
  sample_freq: int -> Sampling frequency of signal in Hz
155
  time_to_wait: float -> Time to wait to build buffer in seconds
156
  '''
157
  # Get number of timesteps for a whole spindle
 
158
  self.sample_freq = sample_freq
 
159
  self.peak = peak
160
  self.buffer = []
161
  self.time_to_buffer = time_to_buffer
162
+ self.stimulate = stimulate
163
 
164
  self.state = States.NO_SPINDLE
165
 
 
192
  return True
193
  return False
194
 
195
+ def step_timesteps(self, point):
196
+ '''
197
+ Step the delayer, ads a point to buffer if necessary.
198
+ Returns True if stimulation is actually done
199
+ '''
200
+ if self.state == States.NO_SPINDLE:
201
+ return False
202
+ elif self.state == States.BUFFERING:
203
+ self.buffer.append(point)
204
+ # If we are done buffering, move on to the waiting stage
205
+ if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
206
+ # Compute the necessary time to wait
207
+ self.time_to_wait = self.compute_time_to_wait()
208
+ self.state = States.DELAYING
209
+ self.buffer = []
210
+ self.delaying_counter = 0
211
+ return False
212
+ elif self.state == States.DELAYING:
213
+ # Check if we are done delaying
214
+ self.delaying_counter += 1
215
+ if self.delaying_counter >= self.time_to_wait * self.sample_freq:
216
+ # Actually stimulate the patient after the delay
217
+ if self.stimulate is not None:
218
+ self.stimulate()
219
+ # Reset state
220
+ self.time_to_wait = -1
221
+ self.state = States.NO_SPINDLE
222
+ return True
223
+ return False
224
+
225
  def detected(self):
226
  if self.state == States.NO_SPINDLE:
227
  self.state = States.BUFFERING
 
228
 
229
  def compute_time_to_wait(self):
230
  """
 
237
  # Returns the index of the last peak in the buffer
238
  peaks, _ = find_peaks(self.buffer, prominence=1)
239
 
240
+ # Make a figure to show the peaks
241
+ if False:
242
+ plt.figure()
243
+ plt.plot(self.buffer)
244
+ for peak in peaks:
245
+ plt.axvline(x=peak)
246
+ plt.plot(np.zeros_like(self.buffer), "--", color="gray")
247
+ plt.show()
248
+
249
+ if len(peaks) == 0:
250
+ print("No peaks found, increase buffer size")
251
+ return (self.sample_freq / 10) * (1.0 / self.sample_freq)
252
+
253
+ # Compute average distance between each peak
254
+ avg_dist = np.mean(np.diff(peaks))
255
+
256
  # Compute the time until next peak and return it
257
+ if (avg_dist < len(self.buffer) - peaks[-1]):
258
+ print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
259
+ return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
260
+ return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)
261
 
262
  class States(Enum):
263
  NO_SPINDLE = 0