Milo Sobral commited on
Commit
bf3350c
·
1 Parent(s): 986653d

Done with sleep staging but needs checking

Browse files
portiloop/src/demo/offline.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
  from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
- from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
@@ -29,6 +29,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stim
29
  # Read the xdf file to a numpy array
30
  print("Loading xdf file...")
31
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
 
32
  # Do the offline filtering of the data
33
  if offline_filtering:
34
  print("Filtering offline...")
@@ -38,13 +39,18 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stim
38
  data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
39
  columns.append("offline_filtered_signal")
40
 
 
 
 
 
 
41
  # Do Wamsley's method
42
  if wamsley:
43
  print("Running Wamsley detection...")
44
  wamsley_data = offline_detect("Wamsley", \
45
  data_whole[:, columns.index("offline_filtered_signal")],\
46
  data_whole[:, columns.index("time_stamps")],\
47
- freq)
48
  wamsley_data = np.expand_dims(wamsley_data, axis=1)
49
  data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
50
  columns.append("wamsley_spindles")
@@ -55,7 +61,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stim
55
  lacourse_data = offline_detect("Lacourse", \
56
  data_whole[:, columns.index("offline_filtered_signal")],\
57
  data_whole[:, columns.index("time_stamps")],\
58
- freq)
59
  lacourse_data = np.expand_dims(lacourse_data, axis=1)
60
  data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
61
  columns.append("lacourse_spindles")
 
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
  from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
+ from portiloop.src.demo.utils import compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
 
29
  # Read the xdf file to a numpy array
30
  print("Loading xdf file...")
31
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
32
+
33
  # Do the offline filtering of the data
34
  if offline_filtering:
35
  print("Filtering offline...")
 
39
  data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
40
  columns.append("offline_filtered_signal")
41
 
42
+ # Do the sleep staging approximation
43
+ if wamsley or lacourse:
44
+ print("Sleep staging...")
45
+ mask = sleep_stage(data_whole[:, columns.index("offline_filtered_signal")], threshold=150, group_size=100)
46
+
47
  # Do Wamsley's method
48
  if wamsley:
49
  print("Running Wamsley detection...")
50
  wamsley_data = offline_detect("Wamsley", \
51
  data_whole[:, columns.index("offline_filtered_signal")],\
52
  data_whole[:, columns.index("time_stamps")],\
53
+ freq, mask)
54
  wamsley_data = np.expand_dims(wamsley_data, axis=1)
55
  data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
56
  columns.append("wamsley_spindles")
 
61
  lacourse_data = offline_detect("Lacourse", \
62
  data_whole[:, columns.index("offline_filtered_signal")],\
63
  data_whole[:, columns.index("time_stamps")],\
64
+ freq, mask)
65
  lacourse_data = np.expand_dims(lacourse_data, axis=1)
66
  data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
67
  columns.append("lacourse_spindles")
portiloop/src/demo/test_offline.py CHANGED
@@ -4,6 +4,8 @@ from portiloop.src.demo.offline import run_offline
4
  from pathlib import Path
5
  import matplotlib.pyplot as plt
6
 
 
 
7
  class TestOffline(unittest.TestCase):
8
 
9
  def setUp(self):
@@ -21,7 +23,7 @@ class TestOffline(unittest.TestCase):
21
  all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
22
  all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
23
  self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
24
- self.xdf_file = Path(__file__).parents[3] / "test_xdf" / "test_file.xdf"
25
 
26
 
27
  def test_all_options(self):
@@ -32,7 +34,7 @@ class TestOffline(unittest.TestCase):
32
  def test_single_option(self):
33
 
34
  # Test options correspond to an index in the possible checkbox group options
35
- test_options = [3, 4]
36
 
37
  res = list(run_offline(
38
  self.xdf_file,
@@ -43,6 +45,7 @@ class TestOffline(unittest.TestCase):
43
  stimulation_phase="Peak",
44
  buffer_time=0.3))
45
  print(res)
 
46
 
47
  def tearDown(self):
48
  pass
 
4
  from pathlib import Path
5
  import matplotlib.pyplot as plt
6
 
7
+ from portiloop.src.demo.utils import sleep_stage, xdf2array
8
+
9
  class TestOffline(unittest.TestCase):
10
 
11
  def setUp(self):
 
23
  all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
24
  all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
25
  self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
26
+ self.xdf_file = Path(__file__).parents[3] / "test_file.xdf"
27
 
28
 
29
  def test_all_options(self):
 
34
  def test_single_option(self):
35
 
36
  # Test options correspond to an index in the possible checkbox group options
37
+ test_options = [0, 1, 2]
38
 
39
  res = list(run_offline(
40
  self.xdf_file,
 
45
  stimulation_phase="Peak",
46
  buffer_time=0.3))
47
  print(res)
48
+ pass
49
 
50
  def tearDown(self):
51
  pass
portiloop/src/demo/utils.py CHANGED
@@ -13,6 +13,32 @@ STREAM_NAMES = {
13
  }
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
17
  def __init__(self):
18
  self.last_detected_ts = time.time()
@@ -87,15 +113,19 @@ def xdf2array(xdf_path, channel):
87
  return np.array(csv_list), columns
88
 
89
 
90
- def offline_detect(method, data, timesteps, freq):
 
 
 
91
  # Get the spindle data from the offline methods
92
  time = np.arange(0, len(data)) / freq
 
93
  if method == "Lacourse":
94
  detector = DetectSpindle(method='Lacourse2018')
95
- spindles, _, _ = detect_Lacourse2018(data, freq, time, detector)
96
  elif method == "Wamsley":
97
  detector = DetectSpindle(method='Wamsley2012')
98
- spindles, _, _ = detect_Wamsley2012(data, freq, time, detector)
99
  else:
100
  raise ValueError("Invalid method")
101
 
@@ -155,3 +185,4 @@ def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles
155
  if wamsley_spindles is not None:
156
  table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
157
  return table
 
 
13
  }
14
 
15
 
16
+ def sleep_stage(data, threshold=150, group_size=2):
17
+ """Sleep stage approximation using a threshold and a group size.
18
+ Returns a numpy array containing all indices in the input data which CAN be used for offline detection.
19
+ These indices can then be used to reconstruct the signal from the original data.
20
+ """
21
+ # Find all indexes where the signal is above or below the threshold
22
+ above = np.where(data > threshold)
23
+ below = np.where(data < -threshold)
24
+ indices = np.concatenate((above, below), axis=1)[0]
25
+
26
+ indices = np.sort(indices)
27
+ # Get all the indices where the difference between two consecutive indices is larger than 100
28
+ groups = np.where(np.diff(indices) <= group_size)[0] + 1
29
+ # Get the important indices
30
+ important_indices = indices[groups]
31
+ # Get all the indices between the important indices
32
+ group_filler = [np.arange(indices[groups[n] - 1] + 1, index) for n, index in enumerate(important_indices)]
33
+ # Create flat array from fillers
34
+ group_filler = np.concatenate(group_filler)
35
+ # Append all group fillers to the indices
36
+ masked_indices = np.sort(np.concatenate((indices, group_filler)))
37
+ unmasked_indices = np.setdiff1d(np.arange(len(data)), masked_indices)
38
+
39
+ return unmasked_indices
40
+
41
+
42
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
43
  def __init__(self):
44
  self.last_detected_ts = time.time()
 
113
  return np.array(csv_list), columns
114
 
115
 
116
+ def offline_detect(method, data, timesteps, freq, mask):
117
+ # Extract only the interesting elements from the mask
118
+ data_masked = data[mask]
119
+
120
  # Get the spindle data from the offline methods
121
  time = np.arange(0, len(data)) / freq
122
+ time_masked = time[mask]
123
  if method == "Lacourse":
124
  detector = DetectSpindle(method='Lacourse2018')
125
+ spindles, _, _ = detect_Lacourse2018(data_masked, freq, time_masked, detector)
126
  elif method == "Wamsley":
127
  detector = DetectSpindle(method='Wamsley2012')
128
+ spindles, _, _ = detect_Wamsley2012(data_masked, freq, time_masked, detector)
129
  else:
130
  raise ValueError("Invalid method")
131
 
 
185
  if wamsley_spindles is not None:
186
  table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
187
  return table
188
+