Milo Sobral commited on
Commit
b38a4ed
Β·
unverified Β·
2 Parent(s): 39c06fc 111f264

Merge pull request #3 from Portiloop/milo/no_portiloop

Browse files
.gitignore CHANGED
@@ -2,6 +2,11 @@
2
  .vscode/
3
  .idea/
4
 
 
 
 
 
 
5
  # Vagrant
6
  .vagrant/
7
 
@@ -122,4 +127,4 @@ venv.bak/
122
  # mypy
123
  .mypy_cache/
124
  .dmypy.json
125
- dmypy.json
 
2
  .vscode/
3
  .idea/
4
 
5
+ # Output from the demo
6
+ output.csv
7
+ # Any xdf file used for testing
8
+ *.xdf
9
+
10
  # Vagrant
11
  .vagrant/
12
 
 
127
  # mypy
128
  .mypy_cache/
129
  .dmypy.json
130
+ dmypy.json
portiloop/notebooks/tests.ipynb CHANGED
@@ -2,35 +2,60 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "16651843",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "source": [
12
- "from portiloop.capture import Capture\n",
13
- "from portiloop.detection import SleepSpindleRealTimeDetector\n",
14
- "from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
15
  "\n",
16
  "my_detector_class = SleepSpindleRealTimeDetector # you may want to implement yours\n",
17
  "my_stimulator_class = SleepSpindleRealTimeStimulator # you may also want to implement yours\n",
18
  "\n",
19
  "cap = Capture(detector_cls=my_detector_class, stimulator_cls=my_stimulator_class)"
20
  ]
21
- },
22
- {
23
- "cell_type": "code",
24
- "execution_count": null,
25
- "id": "cded6bbc",
26
- "metadata": {},
27
- "outputs": [],
28
- "source": []
29
  }
30
  ],
31
  "metadata": {
32
  "kernelspec": {
33
- "display_name": "Python 3",
34
  "language": "python",
35
  "name": "python3"
36
  },
@@ -44,7 +69,12 @@
44
  "name": "python",
45
  "nbconvert_exporter": "python",
46
  "pygments_lexer": "ipython3",
47
- "version": "3.7.3"
 
 
 
 
 
48
  }
49
  },
50
  "nbformat": 4,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "16651843",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
+ "outputs": [
11
+ {
12
+ "data": {
13
+ "application/vnd.jupyter.widget-view+json": {
14
+ "model_id": "f46843d136af4c79a73841b997fa3284",
15
+ "version_major": 2,
16
+ "version_minor": 0
17
+ },
18
+ "text/plain": [
19
+ "VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH2'), Label(value='CH3'), Label(value='CH4'…"
20
+ ]
21
+ },
22
+ "metadata": {},
23
+ "output_type": "display_data"
24
+ },
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Exception in thread Thread-3:\n",
30
+ "Traceback (most recent call last):\n",
31
+ " File \"C:\\Users\\milos\\AppData\\Local\\Programs\\Python\\Python37\\lib\\threading.py\", line 917, in _bootstrap_inner\n",
32
+ " self.run()\n",
33
+ " File \"C:\\Users\\milos\\AppData\\Local\\Programs\\Python\\Python37\\lib\\threading.py\", line 865, in run\n",
34
+ " self._target(*self._args, **self._kwargs)\n",
35
+ " File \"c:\\users\\milos\\documents\\github\\portiloop-software\\portiloop\\src\\capture.py\", line 927, in start_capture\n",
36
+ " detector = detector_cls(threshold, channel=channel) if detector_cls is not None else None\n",
37
+ " File \"c:\\users\\milos\\documents\\github\\portiloop-software\\portiloop\\src\\detection.py\", line 56, in __init__\n",
38
+ " self.interpreters.append(edgetpu.make_interpreter(model_path))\n",
39
+ "NameError: name 'edgetpu' is not defined\n",
40
+ "\n"
41
+ ]
42
+ }
43
+ ],
44
  "source": [
45
+ "from portiloop.src.capture import Capture\n",
46
+ "from portiloop.src.detection import SleepSpindleRealTimeDetector\n",
47
+ "from portiloop.src.stimulation import SleepSpindleRealTimeStimulator\n",
48
  "\n",
49
  "my_detector_class = SleepSpindleRealTimeDetector # you may want to implement yours\n",
50
  "my_stimulator_class = SleepSpindleRealTimeStimulator # you may also want to implement yours\n",
51
  "\n",
52
  "cap = Capture(detector_cls=my_detector_class, stimulator_cls=my_stimulator_class)"
53
  ]
 
 
 
 
 
 
 
 
54
  }
55
  ],
56
  "metadata": {
57
  "kernelspec": {
58
+ "display_name": "Python 3.7.0b2 ('venv': venv)",
59
  "language": "python",
60
  "name": "python3"
61
  },
 
69
  "name": "python",
70
  "nbconvert_exporter": "python",
71
  "pygments_lexer": "ipython3",
72
+ "version": "3.7.0b2"
73
+ },
74
+ "vscode": {
75
+ "interpreter": {
76
+ "hash": "770b75f59a9c369b74cb4cefed633f309c892e213e0ed32eed0937c0a5627480"
77
+ }
78
  }
79
  },
80
  "nbformat": 4,
portiloop/recordings/test_recording.csv ADDED
The diff for this file is too large to render. See raw diff
 
portiloop/src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ADS = False
portiloop/{capture.py β†’ src/capture.py} RENAMED
@@ -1,345 +1,31 @@
1
- import os
2
- import sys
3
 
4
  from time import sleep
5
  import time
6
  import numpy as np
7
- import os
8
  from copy import deepcopy
9
- from pathlib import Path
10
- from datetime import datetime, timedelta
11
  import multiprocessing as mp
12
  import warnings
13
- import shutil
14
  from threading import Thread, Lock
15
- import alsaaudio
 
 
 
 
 
16
 
17
- from EDFlib.edfwriter import EDFwriter
18
- from scipy.signal import firwin
19
 
20
- from portilooplot.jupyter_plot import ProgressPlot
21
- from portiloop.hardware.frontend import Frontend
22
- from portiloop.hardware.leds import LEDs, Color
23
 
 
 
 
24
  from IPython.display import clear_output, display
25
  import ipywidgets as widgets
26
 
27
 
28
- DEFAULT_FRONTEND_CONFIG = [
29
- # nomenclature: name [default setting] [bits 7-0] : description
30
- # Read only ID:
31
- 0x3E, # ID [xx] [REV_ID[2:0], 1, DEV_ID[1:0], NU_CH[1:0]] : (RO)
32
- # Global Settings Across Channels:
33
- 0x96, # CONFIG1 [96] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 250 SPS
34
- 0xC0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]] : No tests
35
- 0x60, # CONFIG3 [60] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
36
- 0x00, # LOFF [00] [COMP_TH[2:0], 0, ILEAD_OFF[1:0], FLEAD_OFF[1:0]] : No lead-off
37
- # Channel-Specific Settings:
38
- 0x61, # CH1SET [61] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] : Channel 1 active, 24 gain, no SRB2 & input shorted
39
- 0x61, # CH2SET [61] [PD2, GAIN2[2:0], SRB2, MUX2[2:0]] : Channel 2 active, 24 gain, no SRB2 & input shorted
40
- 0x61, # CH3SET [61] [PD3, GAIN3[2:0], SRB2, MUX3[2:0]] : Channel 3 active, 24 gain, no SRB2 & input shorted
41
- 0x61, # CH4SET [61] [PD4, GAIN4[2:0], SRB2, MUX4[2:0]] : Channel 4 active, 24 gain, no SRB2 & input shorted
42
- 0x61, # CH5SET [61] [PD5, GAIN5[2:0], SRB2, MUX5[2:0]] : Channel 5 active, 24 gain, no SRB2 & input shorted
43
- 0x61, # CH6SET [61] [PD6, GAIN6[2:0], SRB2, MUX6[2:0]] : Channel 6 active, 24 gain, no SRB2 & input shorted
44
- 0x61, # CH7SET [61] [PD7, GAIN7[2:0], SRB2, MUX7[2:0]] : Channel 7 active, 24 gain, no SRB2 & input shorted
45
- 0x61, # CH8SET [61] [PD8, GAIN8[2:0], SRB2, MUX8[2:0]] : Channel 8 active, 24 gain, no SRB2 & input shorted
46
- 0x00, # BIAS_SENSP [00] [BIASP8, BIASP7, BIASP6, BIASP5, BIASP4, BIASP3, BIASP2, BIASP1] : No bias
47
- 0x00, # BIAS_SENSN [00] [BIASN8, BIASN7, BIASN6, BIASN5, BIASN4, BIASN3, BIASN2, BIASN1] No bias
48
- 0x00, # LOFF_SENSP [00] [LOFFP8, LOFFP7, LOFFP6, LOFFP5, LOFFP4, LOFFP3, LOFFP2, LOFFP1] : No lead-off
49
- 0x00, # LOFF_SENSN [00] [LOFFM8, LOFFM7, LOFFM6, LOFFM5, LOFFM4, LOFFM3, LOFFM2, LOFFM1] : No lead-off
50
- 0x00, # LOFF_FLIP [00] [LOFF_FLIP8, LOFF_FLIP7, LOFF_FLIP6, LOFF_FLIP5, LOFF_FLIP4, LOFF_FLIP3, LOFF_FLIP2, LOFF_FLIP1] : No lead-off flip
51
- # Lead-Off Status Registers (Read-Only Registers):
52
- 0x00, # LOFF_STATP [00] [IN8P_OFF, IN7P_OFF, IN6P_OFF, IN5P_OFF, IN4P_OFF, IN3P_OFF, IN2P_OFF, IN1P_OFF] : Lead-off positive status (RO)
53
- 0x00, # LOFF_STATN [00] [IN8M_OFF, IN7M_OFF, IN6M_OFF, IN5M_OFF, IN4M_OFF, IN3M_OFF, IN2M_OFF, IN1M_OFF] : Laed-off negative status (RO)
54
- # GPIO and OTHER Registers:
55
- 0x0F, # GPIO [0F] [GPIOD[4:1], GPIOC[4:1]] : All GPIOs as inputs
56
- 0x00, # MISC1 [00] [0, 0, SRB1, 0, 0, 0, 0, 0] : Disable SRBM
57
- 0x00, # MISC2 [00] [00] : Unused
58
- 0x00, # CONFIG4 [00] [0, 0, 0, 0, SINGLE_SHOT, 0, PD_LOFF_COMP(bar), 0] : Single-shot, lead-off comparator disabled
59
- ]
60
-
61
- FRONTEND_CONFIG = [
62
- 0x3E, # ID (RO)
63
- 0x95, # CONFIG1 [95] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 500 SPS
64
- 0xD0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]]
65
- 0xFC, # CONFIG3 [E0] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
66
- 0x00, # No lead-off
67
- 0x62, # CH1SET [60] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] set to measure BIAS signal
68
- 0x60, # CH2SET
69
- 0x60, # CH3SET
70
- 0x60, # CH4SET
71
- 0x60, # CH5SET
72
- 0x60, # CH6SET
73
- 0x60, # CH7SET
74
- 0x60, # CH8SET
75
- 0x00, # BIAS_SENSP 00
76
- 0x00, # BIAS_SENSN 00
77
- 0x00, # LOFF_SENSP Lead-off on all positive pins?
78
- 0x00, # LOFF_SENSN Lead-off on all negative pins?
79
- 0x00, # Normal lead-off
80
- 0x00, # Lead-off positive status (RO)
81
- 0x00, # Lead-off negative status (RO)
82
- 0x00, # All GPIOs as output ?
83
- 0x20, # Enable SRB1
84
- ]
85
-
86
-
87
- LEADOFF_CONFIG = [
88
- 0x3E, # ID (RO)
89
- 0x95, # CONFIG1 [95] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 500 SPS
90
- 0xC0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]]
91
- 0xFC, # CONFIG3 [E0] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
92
- 0x00, # No lead-off
93
- 0x60, # CH1SET [60] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] set to measure BIAS signal
94
- 0x60, # CH2SET
95
- 0x60, # CH3SET
96
- 0x60, # CH4SET
97
- 0x60, # CH5SET
98
- 0x60, # CH6SET
99
- 0x60, # CH7SET
100
- 0x60, # CH8SET
101
- 0x00, # BIAS_SENSP 00
102
- 0x00, # BIAS_SENSN 00
103
- 0xFF, # LOFF_SENSP Lead-off on all positive pins?
104
- 0xFF, # LOFF_SENSN Lead-off on all negative pins?
105
- 0x00, # Normal lead-off
106
- 0x00, # Lead-off positive status (RO)
107
- 0x00, # Lead-off negative status (RO)
108
- 0x00, # All GPIOs as output ?
109
- 0x20, # Enable SRB1
110
- 0x00,
111
- 0x02,
112
- ]
113
-
114
- EDF_PATH = Path.home() / 'workspace' / 'edf_recording'
115
-
116
-
117
- def to_ads_frequency(frequency):
118
- possible_datarates = [250, 500, 1000, 2000, 4000, 8000, 16000]
119
- dr = 16000
120
- for i in possible_datarates:
121
- if i >= frequency:
122
- dr = i
123
- break
124
- return dr
125
-
126
- def mod_config(config, datarate, channel_modes):
127
-
128
- # datarate:
129
-
130
- possible_datarates = [(250, 0x06),
131
- (500, 0x05),
132
- (1000, 0x04),
133
- (2000, 0x03),
134
- (4000, 0x02),
135
- (8000, 0x01),
136
- (16000, 0x00)]
137
- mod_dr = 0x00
138
- for i, j in possible_datarates:
139
- if i >= datarate:
140
- mod_dr = j
141
- break
142
-
143
- new_cf1 = config[1] & 0xF8
144
- new_cf1 = new_cf1 | mod_dr
145
- config[1] = new_cf1
146
-
147
- # bias:
148
- assert len(channel_modes) == 7
149
- config[13] = 0x00 # clear BIAS_SENSP
150
- config[14] = 0x00 # clear BIAS_SENSN
151
- for chan_i, chan_mode in enumerate(channel_modes):
152
- n = 6 + chan_i
153
- mod = config[n] & 0x78 # clear PDn and MUX[2:0]
154
- if chan_mode == 'simple':
155
- # If channel is activated, we send the channel's output to the BIAS mechanism
156
- bit_i = 1 << chan_i + 1
157
- config[13] = config[13] | bit_i
158
- config[14] = config[14] | bit_i
159
- elif chan_mode == 'disabled':
160
- mod = mod | 0x81 # PDn = 1 and input shorted (001)
161
- else:
162
- assert False, f"Wrong key: {chan_mode}."
163
- config[n] = mod
164
- for n, c in enumerate(config): # print ADS1299 configuration registers
165
- print(f"config[{n}]:\t{c:08b}\t({hex(c)})")
166
- return config
167
-
168
-
169
- def filter_24(value):
170
- return (value * 4.5) / (2**23 - 1) / 24.0 * 1e6 # 23 because 1 bit is lost for sign
171
-
172
-
173
- def filter_2scomplement_np(value):
174
- return np.where((value & (1 << 23)) != 0, value - (1 << 24), value)
175
-
176
-
177
- def filter_np(value):
178
- return filter_24(filter_2scomplement_np(value))
179
-
180
-
181
- def shift_numpy(arr, num, fill_value=np.nan):
182
- result = np.empty_like(arr)
183
- if num > 0:
184
- result[:num] = fill_value
185
- result[num:] = arr[:-num]
186
- elif num < 0:
187
- result[num:] = fill_value
188
- result[:num] = arr[-num:]
189
- else:
190
- result[:] = arr
191
- return result
192
-
193
-
194
- class FIR:
195
- def __init__(self, nb_channels, coefficients, buffer=None):
196
-
197
- self.coefficients = np.expand_dims(np.array(coefficients), axis=1)
198
- self.taps = len(self.coefficients)
199
- self.nb_channels = nb_channels
200
- self.buffer = np.array(z) if buffer is not None else np.zeros((self.taps, self.nb_channels))
201
-
202
- def filter(self, x):
203
- self.buffer = shift_numpy(self.buffer, 1, x)
204
- filtered = np.sum(self.buffer * self.coefficients, axis=0)
205
- return filtered
206
-
207
-
208
- class FilterPipeline:
209
- def __init__(self,
210
- nb_channels,
211
- sampling_rate,
212
- power_line_fq=60,
213
- use_custom_fir=False,
214
- custom_fir_order=20,
215
- custom_fir_cutoff=30,
216
- alpha_avg=0.1,
217
- alpha_std=0.001,
218
- epsilon=0.000001,
219
- filter_args=[]):
220
- if len(filter_args) > 0:
221
- use_fir, use_notch, use_std = filter_args
222
- else:
223
- use_fir=True,
224
- use_notch=True,
225
- use_std=True
226
- self.use_fir = use_fir
227
- self.use_notch = use_notch
228
- self.use_std = use_std
229
- self.nb_channels = nb_channels
230
- assert power_line_fq in [50, 60], f"The only supported power line frequencies are 50 Hz and 60 Hz"
231
- if power_line_fq == 60:
232
- self.notch_coeff1 = -0.12478308884588535
233
- self.notch_coeff2 = 0.98729186796473023
234
- self.notch_coeff3 = 0.99364593398236511
235
- self.notch_coeff4 = -0.12478308884588535
236
- self.notch_coeff5 = 0.99364593398236511
237
- else:
238
- self.notch_coeff1 = -0.61410695998423581
239
- self.notch_coeff2 = 0.98729186796473023
240
- self.notch_coeff3 = 0.99364593398236511
241
- self.notch_coeff4 = -0.61410695998423581
242
- self.notch_coeff5 = 0.99364593398236511
243
- self.dfs = [np.zeros(self.nb_channels), np.zeros(self.nb_channels)]
244
-
245
- self.moving_average = None
246
- self.moving_variance = np.zeros(self.nb_channels)
247
- self.ALPHA_AVG = alpha_avg
248
- self.ALPHA_STD = alpha_std
249
- self.EPSILON = epsilon
250
-
251
- if use_custom_fir:
252
- self.fir_coef = firwin(numtaps=custom_fir_order+1, cutoff=custom_fir_cutoff, fs=sampling_rate)
253
- else:
254
- self.fir_coef = [
255
- 0.001623780150148094927192721215192250384,
256
- 0.014988684599373741992978104065059596905,
257
- 0.021287595318265635502275046064823982306,
258
- 0.007349500393709578957568417933998716762,
259
- -0.025127515717112181709014251396183681209,
260
- -0.052210507359822452833064687638398027048,
261
- -0.039273839505489904766477593511808663607,
262
- 0.033021568427940004020193498490698402748,
263
- 0.147606943281569008563636202779889572412,
264
- 0.254000252034505602516389899392379447818,
265
- 0.297330876398883392486283128164359368384,
266
- 0.254000252034505602516389899392379447818,
267
- 0.147606943281569008563636202779889572412,
268
- 0.033021568427940004020193498490698402748,
269
- -0.039273839505489904766477593511808663607,
270
- -0.052210507359822452833064687638398027048,
271
- -0.025127515717112181709014251396183681209,
272
- 0.007349500393709578957568417933998716762,
273
- 0.021287595318265635502275046064823982306,
274
- 0.014988684599373741992978104065059596905,
275
- 0.001623780150148094927192721215192250384]
276
- self.fir = FIR(self.nb_channels, self.fir_coef)
277
-
278
- def filter(self, value):
279
- """
280
- value: a numpy array of shape (data series, channels)
281
- """
282
- for i, x in enumerate(value): # loop over the data series
283
- # FIR:
284
- if self.use_fir:
285
- x = self.fir.filter(x)
286
- # notch:
287
- if self.use_notch:
288
- denAccum = (x - self.notch_coeff1 * self.dfs[0]) - self.notch_coeff2 * self.dfs[1]
289
- x = (self.notch_coeff3 * denAccum + self.notch_coeff4 * self.dfs[0]) + self.notch_coeff5 * self.dfs[1]
290
- self.dfs[1] = self.dfs[0]
291
- self.dfs[0] = denAccum
292
- # standardization:
293
- if self.use_std:
294
- if self.moving_average is not None:
295
- delta = x - self.moving_average
296
- self.moving_average = self.moving_average + self.ALPHA_AVG * delta
297
- self.moving_variance = (1 - self.ALPHA_STD) * (self.moving_variance + self.ALPHA_STD * delta**2)
298
- moving_std = np.sqrt(self.moving_variance)
299
- x = (x - self.moving_average) / (moving_std + self.EPSILON)
300
- else:
301
- self.moving_average = x
302
- value[i] = x
303
- return value
304
-
305
-
306
- class LiveDisplay():
307
- def __init__(self, channel_names, window_len=100):
308
- self.datapoint_dim = len(channel_names)
309
- self.history = []
310
- self.pp = ProgressPlot(plot_names=channel_names, max_window_len=window_len)
311
- self.matplotlib = False
312
-
313
- def add_datapoints(self, datapoints):
314
- """
315
- Adds 8 lists of datapoints to the plot
316
-
317
- Args:
318
- datapoints: list of 8 lists of floats (or list of 8 floats)
319
- """
320
- if self.matplotlib:
321
- import matplotlib.pyplot as plt
322
- disp_list = []
323
- for datapoint in datapoints:
324
- d = [[elt] for elt in datapoint]
325
- disp_list.append(d)
326
-
327
- if self.matplotlib:
328
- self.history += d[1]
329
-
330
- if not self.matplotlib:
331
- self.pp.update_with_datapoints(disp_list)
332
- elif len(self.history) == 1000:
333
- plt.plot(self.history)
334
- plt.show()
335
- self.history = []
336
-
337
- def add_datapoint(self, datapoint):
338
- disp_list = [[elt] for elt in datapoint]
339
- self.pp.update(disp_list)
340
-
341
-
342
- def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time_msg_in, channel_states):
343
  """
344
  Args:
345
  p_data_o: multiprocessing.Pipe: captured datapoints are put here
@@ -434,69 +120,8 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
434
  p_msg_io.send('STOP')
435
  p_msg_io.close()
436
  p_data_o.close()
437
-
438
-
439
- class DummyAlsaMixer:
440
- def __init__(self):
441
- self.volume = 50
442
-
443
- def getvolume(self):
444
- return [self.volume]
445
 
446
- def setvolume(self, volume):
447
- self.volume = volume
448
-
449
 
450
- class UpStateDelayer:
451
- def __init__(self, sample_freq, spindle_freq, peak):
452
- '''
453
- args:
454
- buffer_size: int -> Size of desired buffer in length
455
- sample_freq: int -> Sampling frequency of signal in Hz
456
- '''
457
- # Get number of timesteps for a whole spindle
458
- self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
459
- self.sample_freq = sample_freq
460
- self.buffer_size = 1.5 * self.spindle_timesteps
461
- self.peak = peak
462
- self.buffer = []
463
-
464
- def add_point(self, point):
465
- '''
466
- Adds a point to the buffer to be able to keep track of peaks
467
- '''
468
- self.buffer.append(point)
469
- if len(self.buffer) > self.buffer_size:
470
- self.buffer.pop(0)
471
-
472
- def stimulate(self):
473
- # Calculate how far away is last peak
474
- last_peak = -1
475
- count = 0
476
- for idx, point in reversed(list(enumerate(self.buffer))):
477
- if self.peak:
478
- try:
479
- sup = point >= self.buffer[idx+1]
480
- except IndexError:
481
- sup = False
482
- try:
483
- inf = point >= self.buffer[idx-1]
484
- except IndexError:
485
- inf = False
486
- else:
487
- try:
488
- sup = point <= self.buffer[idx+1]
489
- except IndexError:
490
- sup = False
491
- try:
492
- inf = point <= self.buffer[idx-1]
493
- except IndexError:
494
- inf = False
495
- if sup and inf:
496
- last_peak = count
497
- return self.spindle_timesteps - last_peak
498
- count += 1
499
- return -1
500
 
501
  class Capture:
502
  def __init__(self, detector_cls=None, stimulator_cls=None):
@@ -521,13 +146,10 @@ class Capture:
521
  self.threshold = 0.5
522
  self.lsl = False
523
  self.display = False
 
524
  self.python_clock = True
525
  self.edf_writer = None
526
  self.edf_buffer = []
527
- self.nb_signals = 8
528
- self.samples_per_datarecord_array = self.frequency
529
- self.physical_max = 5
530
- self.physical_min = -5
531
  self.signal_labels = ['Common Mode', 'ch2', 'ch3', 'ch4', 'ch5', 'ch6', 'ch7', 'ch8']
532
  self._lock_msg_out = Lock()
533
  self._msg_out = None
@@ -546,22 +168,25 @@ class Capture:
546
  self._pause_detect_lock = Lock()
547
  self._pause_detect = True
548
 
549
- try:
550
- mixers = alsaaudio.mixers()
551
- if len(mixers) <= 0:
 
 
 
 
 
 
 
 
 
552
  warnings.warn(f"No ALSA mixer found.")
553
  self.mixer = DummyAlsaMixer()
554
- elif 'PCM' in mixers:
555
- self.mixer = alsaaudio.Mixer(control='PCM')
556
- else:
557
- warnings.warn(f"Could not find mixer PCM, using {mixers[0]} instead.")
558
- self.mixer = alsaaudio.Mixer(control=mixers[0])
559
- except ALSAAudioError as e:
560
- warnings.warn(f"No ALSA mixer found.")
561
  self.mixer = DummyAlsaMixer()
562
-
563
- self.volume = self.mixer.getvolume()[0] # we will set the same volume on all channels
564
-
565
 
566
  # widgets ===============================
567
 
@@ -678,7 +303,7 @@ class Capture:
678
  )
679
 
680
  self.b_clock = widgets.ToggleButtons(
681
- options=['Coral', 'ADS'],
682
  description='Clock:',
683
  disabled=False,
684
  button_style='', # 'success', 'info', 'warning', 'danger' or ''
@@ -694,6 +319,15 @@ class Capture:
694
  tooltips=['North America 60 Hz',
695
  'Europe 50 Hz'],
696
  )
 
 
 
 
 
 
 
 
 
697
 
698
  self.b_custom_fir = widgets.ToggleButtons(
699
  options=['Default', 'Custom'],
@@ -890,6 +524,7 @@ class Capture:
890
  self.b_spindle_mode.observe(self.on_b_spindle_mode, 'value')
891
  self.b_spindle_freq.observe(self.on_b_spindle_freq, 'value')
892
  self.b_power_line.observe(self.on_b_power_line, 'value')
 
893
  self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
894
  self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
895
  self.b_custom_fir_cutoff.observe(self.on_b_custom_fir_cutoff, 'value')
@@ -900,9 +535,10 @@ class Capture:
900
  self.b_test_stimulus.on_click(self.on_b_test_stimulus)
901
  self.b_test_impedance.on_click(self.on_b_test_impedance)
902
  self.b_pause.observe(self.on_b_pause, 'value')
903
-
904
  self.display_buttons()
905
 
 
906
  def __del__(self):
907
  self.b_capture.close()
908
 
@@ -912,6 +548,7 @@ class Capture:
912
  self.b_frequency,
913
  self.b_duration,
914
  self.b_filename,
 
915
  self.b_power_line,
916
  self.b_clock,
917
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
@@ -941,6 +578,7 @@ class Capture:
941
  self.b_radio_ch7.disabled = False
942
  self.b_radio_ch8.disabled = False
943
  self.b_power_line.disabled = False
 
944
  self.b_channel_detect.disabled = False
945
  self.b_spindle_freq.disabled = False
946
  self.b_spindle_mode.disabled = False
@@ -981,6 +619,7 @@ class Capture:
981
  self.b_channel_detect.disabled = True
982
  self.b_spindle_freq.disabled = True
983
  self.b_spindle_mode.disabled = True
 
984
  self.b_power_line.disabled = True
985
  self.b_polyak_mean.disabled = True
986
  self.b_polyak_std.disabled = True
@@ -1067,8 +706,6 @@ class Capture:
1067
  self._t_capture = None
1068
  self.enable_buttons()
1069
 
1070
-
1071
-
1072
  def on_b_custom_fir(self, value):
1073
  val = value['new']
1074
  if val == 'Default':
@@ -1083,13 +720,20 @@ class Capture:
1083
  self.python_clock = True
1084
  elif val == 'ADS':
1085
  self.python_clock = False
1086
-
 
 
 
 
 
 
 
1087
  def on_b_power_line(self, value):
1088
  val = value['new']
1089
  if val == '60 Hz':
1090
  self.power_line = 60
1091
  elif val == '50 Hz':
1092
- self.python_clock = 50
1093
 
1094
  def on_b_frequency(self, value):
1095
  val = value['new']
@@ -1243,42 +887,6 @@ class Capture:
1243
  elif val == 'Paused':
1244
  with self._pause_detect_lock:
1245
  self._pause_detect = True
1246
-
1247
- def open_recording_file(self):
1248
- nb_signals = self.nb_signals
1249
- samples_per_datarecord_array = self.samples_per_datarecord_array
1250
- physical_max = self.physical_max
1251
- physical_min = self.physical_min
1252
- signal_labels = self.signal_labels
1253
-
1254
- print(f"Will store edf recording in {self.filename}")
1255
-
1256
- self.edf_writer = EDFwriter(p_path=str(self.filename),
1257
- f_file_type=EDFwriter.EDFLIB_FILETYPE_EDFPLUS,
1258
- number_of_signals=nb_signals)
1259
-
1260
- for signal in range(nb_signals):
1261
- assert self.edf_writer.setSampleFrequency(signal, samples_per_datarecord_array) == 0
1262
- assert self.edf_writer.setPhysicalMaximum(signal, physical_max) == 0
1263
- assert self.edf_writer.setPhysicalMinimum(signal, physical_min) == 0
1264
- assert self.edf_writer.setDigitalMaximum(signal, 32767) == 0
1265
- assert self.edf_writer.setDigitalMinimum(signal, -32768) == 0
1266
- assert self.edf_writer.setSignalLabel(signal, signal_labels[signal]) == 0
1267
- assert self.edf_writer.setPhysicalDimension(signal, 'V') == 0
1268
-
1269
- def close_recording_file(self):
1270
- assert self.edf_writer.close() == 0
1271
-
1272
- def add_recording_data(self, data):
1273
- self.edf_buffer += data
1274
- if len(self.edf_buffer) >= self.samples_per_datarecord_array:
1275
- datarecord_array = self.edf_buffer[:self.samples_per_datarecord_array]
1276
- self.edf_buffer = self.edf_buffer[self.samples_per_datarecord_array:]
1277
- datarecord_array = np.array(datarecord_array).transpose()
1278
- assert len(datarecord_array) == self.nb_signals, f"len(data)={len(data)}!={self.nb_signals}"
1279
- for d in datarecord_array:
1280
- assert len(d) == self.samples_per_datarecord_array, f"{len(d)}!={self.samples_per_datarecord_array}"
1281
- assert self.edf_writer.writeSamples(d) == 0
1282
 
1283
  def start_capture(self,
1284
  filter,
@@ -1292,15 +900,17 @@ class Capture:
1292
  viz,
1293
  width,
1294
  python_clock):
1295
- if self.__capture_on:
1296
- warnings.warn("Capture is already ongoing, ignoring command.")
1297
- return
1298
- else:
1299
- self.__capture_on = True
1300
- p_msg_io, p_msg_io_2 = mp.Pipe()
1301
- p_data_i, p_data_o = mp.Pipe(duplex=False)
1302
- SAMPLE_TIME = 1 / self.frequency
1303
 
 
 
 
 
 
 
 
 
 
 
1304
  if filter:
1305
  fp = FilterPipeline(nb_channels=8,
1306
  sampling_rate=self.frequency,
@@ -1312,28 +922,37 @@ class Capture:
1312
  alpha_std=self.polyak_std,
1313
  epsilon=self.epsilon,
1314
  filter_args=filter_args)
1315
-
 
1316
  detector = detector_cls(threshold, channel=channel) if detector_cls is not None else None
1317
  stimulator = stimulator_cls() if stimulator_cls is not None else None
1318
 
1319
- self._p_capture = mp.Process(target=_capture_process,
1320
- args=(p_data_o,
1321
- p_msg_io_2,
1322
- self.duration,
1323
- self.frequency,
1324
- python_clock,
1325
- 1.0,
1326
- self.channel_states)
1327
- )
1328
- self._p_capture.start()
1329
- print(f"PID capture: {self._p_capture.pid}")
 
 
 
 
 
1330
 
 
1331
  if viz:
1332
  live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
1333
 
 
1334
  if record:
1335
- self.open_recording_file()
1336
 
 
1337
  if lsl:
1338
  from pylsl import StreamInfo, StreamOutlet
1339
  lsl_info = StreamInfo(name='Portiloop Filtered',
@@ -1353,53 +972,71 @@ class Capture:
1353
 
1354
  buffer = []
1355
 
 
1356
  if not self.spindle_detection_mode == 'Fast' and stimulator is not None:
1357
- stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_freq, self.spindle_detection_mode == 'Peak')
1358
  stimulator.add_delayer(stimulation_delayer)
1359
  else:
1360
  stimulation_delayer = None
1361
 
 
1362
  while True:
1363
- with self._lock_msg_out:
1364
- if self._msg_out is not None:
1365
- p_msg_io.send(self._msg_out)
1366
- self._msg_out = None
1367
- if p_msg_io.poll():
1368
- mess = p_msg_io.recv()
1369
- if mess == 'STOP':
1370
- break
1371
- elif mess[0] == 'PRT':
1372
- print(mess[1])
 
 
 
 
 
 
 
 
 
 
 
1373
 
1374
- # retrieve all data points from p_data and put them in a list of np.array:
1375
- point = None
1376
- if p_data_i.poll(timeout=SAMPLE_TIME):
1377
- point = p_data_i.recv()
1378
- else:
1379
- continue
1380
-
1381
- n_array = np.array([point])
1382
- n_array_raw = filter_np(n_array)
1383
 
 
1384
  if filter:
1385
  n_array = fp.filter(deepcopy(n_array_raw))
1386
  else:
1387
- n_array = n_array_raw
1388
-
 
1389
  filtered_point = n_array.tolist()
1390
 
 
1391
  if lsl:
1392
  raw_point = n_array_raw.tolist()
1393
  lsl_outlet_raw.push_sample(raw_point[-1])
1394
  lsl_outlet.push_sample(filtered_point[-1])
1395
-
 
1396
  if stimulation_delayer is not None:
1397
- stimulation_delayer.add_point(filtered_point[channel-1])
1398
 
 
1399
  with self._pause_detect_lock:
1400
  pause = self._pause_detect
 
 
1401
  if detector is not None and not pause:
 
1402
  detection_signal = detector.detect(filtered_point)
 
 
1403
  if stimulator is not None:
1404
  stimulator.stimulate(detection_signal)
1405
  with self._test_stimulus_lock:
@@ -1407,35 +1044,36 @@ class Capture:
1407
  self._test_stimulus = False
1408
  if test_stimulus:
1409
  stimulator.test_stimulus()
 
 
 
1410
 
 
1411
  buffer += filtered_point
1412
  if len(buffer) >= 50:
1413
-
1414
  if viz:
1415
  live_disp.add_datapoints(buffer)
1416
-
1417
  if record:
1418
- self.add_recording_data(buffer)
1419
-
1420
  buffer = []
1421
 
1422
- # empty pipes
1423
- while True:
1424
- if p_data_i.poll():
1425
- _ = p_data_i.recv()
1426
- elif p_msg_io.poll():
1427
- _ = p_msg_io.recv()
1428
- else:
1429
- break
 
1430
 
1431
- p_data_i.close()
1432
- p_msg_io.close()
 
 
1433
 
1434
  if record:
1435
- self.close_recording_file()
1436
-
1437
- self._p_capture.join()
1438
- self.__capture_on = False
1439
 
1440
 
1441
  if __name__ == "__main__":
 
1
+
 
2
 
3
  from time import sleep
4
  import time
5
  import numpy as np
 
6
  from copy import deepcopy
7
+ from datetime import datetime
 
8
  import multiprocessing as mp
9
  import warnings
 
10
  from threading import Thread, Lock
11
+ from portiloop.src import ADS
12
+
13
+ if ADS:
14
+ import alsaaudio
15
+ from portiloop.src.hardware.frontend import Frontend
16
+ from portiloop.src.hardware.leds import LEDs, Color
17
 
18
+ from portiloop.src.stimulation import UpStateDelayer
 
19
 
 
 
 
20
 
21
+ from portiloop.src.processing import FilterPipeline, int_to_float
22
+ from portiloop.src.config import mod_config, LEADOFF_CONFIG, FRONTEND_CONFIG, to_ads_frequency
23
+ from portiloop.src.utils import FileReader, LiveDisplay, DummyAlsaMixer, EDFRecorder, EDF_PATH
24
  from IPython.display import clear_output, display
25
  import ipywidgets as widgets
26
 
27
 
28
+ def capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time_msg_in, channel_states):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
  Args:
31
  p_data_o: multiprocessing.Pipe: captured datapoints are put here
 
120
  p_msg_io.send('STOP')
121
  p_msg_io.close()
122
  p_data_o.close()
 
 
 
 
 
 
 
 
123
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  class Capture:
127
  def __init__(self, detector_cls=None, stimulator_cls=None):
 
146
  self.threshold = 0.5
147
  self.lsl = False
148
  self.display = False
149
+ self.signal_input = "ADS"
150
  self.python_clock = True
151
  self.edf_writer = None
152
  self.edf_buffer = []
 
 
 
 
153
  self.signal_labels = ['Common Mode', 'ch2', 'ch3', 'ch4', 'ch5', 'ch6', 'ch7', 'ch8']
154
  self._lock_msg_out = Lock()
155
  self._msg_out = None
 
168
  self._pause_detect_lock = Lock()
169
  self._pause_detect = True
170
 
171
+ if ADS:
172
+ try:
173
+ mixers = alsaaudio.mixers()
174
+ if len(mixers) <= 0:
175
+ warnings.warn(f"No ALSA mixer found.")
176
+ self.mixer = DummyAlsaMixer()
177
+ elif 'PCM' in mixers:
178
+ self.mixer = alsaaudio.Mixer(control='PCM')
179
+ else:
180
+ warnings.warn(f"Could not find mixer PCM, using {mixers[0]} instead.")
181
+ self.mixer = alsaaudio.Mixer(control=mixers[0])
182
+ except ALSAAudioError as e:
183
  warnings.warn(f"No ALSA mixer found.")
184
  self.mixer = DummyAlsaMixer()
185
+
186
+ self.volume = self.mixer.getvolume()[0] # we will set the same volume on all channels
187
+ else:
 
 
 
 
188
  self.mixer = DummyAlsaMixer()
189
+ self.volume = self.mixer.getvolume()[0]
 
 
190
 
191
  # widgets ===============================
192
 
 
303
  )
304
 
305
  self.b_clock = widgets.ToggleButtons(
306
+ options=['ADS', 'Coral'],
307
  description='Clock:',
308
  disabled=False,
309
  button_style='', # 'success', 'info', 'warning', 'danger' or ''
 
319
  tooltips=['North America 60 Hz',
320
  'Europe 50 Hz'],
321
  )
322
+
323
+ self.b_signal_input = widgets.ToggleButtons(
324
+ options=['ADS', 'File'],
325
+ description='Signal Input:',
326
+ disabled=False,
327
+ button_style='', # 'success', 'info', 'warning', 'danger' or ''
328
+ tooltips=['Read data from ADS.',
329
+ 'Read data from file.'],
330
+ )
331
 
332
  self.b_custom_fir = widgets.ToggleButtons(
333
  options=['Default', 'Custom'],
 
524
  self.b_spindle_mode.observe(self.on_b_spindle_mode, 'value')
525
  self.b_spindle_freq.observe(self.on_b_spindle_freq, 'value')
526
  self.b_power_line.observe(self.on_b_power_line, 'value')
527
+ self.b_signal_input.observe(self.on_b_power_line, 'value')
528
  self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
529
  self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
530
  self.b_custom_fir_cutoff.observe(self.on_b_custom_fir_cutoff, 'value')
 
535
  self.b_test_stimulus.on_click(self.on_b_test_stimulus)
536
  self.b_test_impedance.on_click(self.on_b_test_impedance)
537
  self.b_pause.observe(self.on_b_pause, 'value')
538
+
539
  self.display_buttons()
540
 
541
+
542
  def __del__(self):
543
  self.b_capture.close()
544
 
 
548
  self.b_frequency,
549
  self.b_duration,
550
  self.b_filename,
551
+ self.b_signal_input,
552
  self.b_power_line,
553
  self.b_clock,
554
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
 
578
  self.b_radio_ch7.disabled = False
579
  self.b_radio_ch8.disabled = False
580
  self.b_power_line.disabled = False
581
+ self.b_signal_input.disabled = False
582
  self.b_channel_detect.disabled = False
583
  self.b_spindle_freq.disabled = False
584
  self.b_spindle_mode.disabled = False
 
619
  self.b_channel_detect.disabled = True
620
  self.b_spindle_freq.disabled = True
621
  self.b_spindle_mode.disabled = True
622
+ self.b_signal_input.disabled = True
623
  self.b_power_line.disabled = True
624
  self.b_polyak_mean.disabled = True
625
  self.b_polyak_std.disabled = True
 
706
  self._t_capture = None
707
  self.enable_buttons()
708
 
 
 
709
  def on_b_custom_fir(self, value):
710
  val = value['new']
711
  if val == 'Default':
 
720
  self.python_clock = True
721
  elif val == 'ADS':
722
  self.python_clock = False
723
+
724
+ def on_b_signal_input(self, value):
725
+ val = value['new']
726
+ if val == "ADS":
727
+ self.signal_input = "ADS"
728
+ elif val == "File":
729
+ self.signal_input = "File"
730
+
731
  def on_b_power_line(self, value):
732
  val = value['new']
733
  if val == '60 Hz':
734
  self.power_line = 60
735
  elif val == '50 Hz':
736
+ self.power_line = 50
737
 
738
  def on_b_frequency(self, value):
739
  val = value['new']
 
887
  elif val == 'Paused':
888
  with self._pause_detect_lock:
889
  self._pause_detect = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
 
891
  def start_capture(self,
892
  filter,
 
900
  viz,
901
  width,
902
  python_clock):
 
 
 
 
 
 
 
 
903
 
904
+ if self.signal_input == "ADS":
905
+ if self.__capture_on:
906
+ warnings.warn("Capture is already ongoing, ignoring command.")
907
+ return
908
+ else:
909
+ self.__capture_on = True
910
+ p_msg_io, p_msg_io_2 = mp.Pipe()
911
+ p_data_i, p_data_o = mp.Pipe(duplex=False)
912
+
913
+ # Initialize filtering pipeline
914
  if filter:
915
  fp = FilterPipeline(nb_channels=8,
916
  sampling_rate=self.frequency,
 
922
  alpha_std=self.polyak_std,
923
  epsilon=self.epsilon,
924
  filter_args=filter_args)
925
+
926
+ # Initialize detector and stimulator
927
  detector = detector_cls(threshold, channel=channel) if detector_cls is not None else None
928
  stimulator = stimulator_cls() if stimulator_cls is not None else None
929
 
930
+ # Launch the capture process
931
+ if self.signal_input == "ADS":
932
+ self._p_capture = mp.Process(target=capture_process,
933
+ args=(p_data_o,
934
+ p_msg_io_2,
935
+ self.duration,
936
+ self.frequency,
937
+ python_clock,
938
+ 1.0,
939
+ self.channel_states)
940
+ )
941
+ self._p_capture.start()
942
+ print(f"PID capture: {self._p_capture.pid}")
943
+ else:
944
+ filename = "INSERT FILENAME" # TODO
945
+ file_reader = FileReader(filename)
946
 
947
+ # Initialize display if requested
948
  if viz:
949
  live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
950
 
951
+ # Initialize recording if requested
952
  if record:
953
+ recorder = EDFRecorder(self.signal_label)
954
 
955
+ # Initialize LSL to stream if requested
956
  if lsl:
957
  from pylsl import StreamInfo, StreamOutlet
958
  lsl_info = StreamInfo(name='Portiloop Filtered',
 
972
 
973
  buffer = []
974
 
975
+ # Initialize stimulation delayer if requested
976
  if not self.spindle_detection_mode == 'Fast' and stimulator is not None:
977
+ stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_freq, self.spindle_detection_mode == 'Peak', time_to_buffer=0.1)
978
  stimulator.add_delayer(stimulation_delayer)
979
  else:
980
  stimulation_delayer = None
981
 
982
+ # Main capture loop
983
  while True:
984
+ if self.signal_input == "ADS":
985
+ # Send message in communication pipe if we have one
986
+ with self._lock_msg_out:
987
+ if self._msg_out is not None:
988
+ p_msg_io.send(self._msg_out)
989
+ self._msg_out = None
990
+
991
+ # Check if we have received a message in communication pipe
992
+ if p_msg_io.poll():
993
+ mess = p_msg_io.recv()
994
+ if mess == 'STOP':
995
+ break
996
+ elif mess[0] == 'PRT':
997
+ print(mess[1])
998
+
999
+ # Retrieve all data points from data pipe p_data
1000
+ point = None
1001
+ if p_data_i.poll(timeout=(1 / self.frequency)):
1002
+ point = p_data_i.recv()
1003
+ else:
1004
+ continue
1005
 
1006
+ # Convert point from int to corresponding value in microvolts
1007
+ n_array_raw = int_to_float(np.array([point]))
1008
+ elif self.signal_input == "File":
1009
+ n_array_raw, gt_stimulation = file_reader.get_point()
 
 
 
 
 
1010
 
1011
+ # Go through filtering pipeline
1012
  if filter:
1013
  n_array = fp.filter(deepcopy(n_array_raw))
1014
  else:
1015
+ n_array = deepcopy(n_array_raw)
1016
+
1017
+ # Contains the filtered point (if filtering is off, contains a copy of the raw point)
1018
  filtered_point = n_array.tolist()
1019
 
1020
+ # Send both raw and filtered points over LSL
1021
  if lsl:
1022
  raw_point = n_array_raw.tolist()
1023
  lsl_outlet_raw.push_sample(raw_point[-1])
1024
  lsl_outlet.push_sample(filtered_point[-1])
1025
+
1026
+ # Adds point to buffer for delayed stimulation
1027
  if stimulation_delayer is not None:
1028
+ stimulation_delayer.step(filtered_point[0][channel-1])
1029
 
1030
+ # Check if detection is on or off
1031
  with self._pause_detect_lock:
1032
  pause = self._pause_detect
1033
+
1034
+ # If detection is on
1035
  if detector is not None and not pause:
1036
+ # Detect using the latest point
1037
  detection_signal = detector.detect(filtered_point)
1038
+
1039
+ # Stimulate
1040
  if stimulator is not None:
1041
  stimulator.stimulate(detection_signal)
1042
  with self._test_stimulus_lock:
 
1044
  self._test_stimulus = False
1045
  if test_stimulus:
1046
  stimulator.test_stimulus()
1047
+
1048
+ if self.signal_input == "File" and gt_stimulation:
1049
+ stimulator.send_stimulation("GROUND_TRUTH_STIM", False)
1050
 
1051
+ # Add point to the buffer to send to viz and recorder
1052
  buffer += filtered_point
1053
  if len(buffer) >= 50:
 
1054
  if viz:
1055
  live_disp.add_datapoints(buffer)
 
1056
  if record:
1057
+ recorder.add_recording_data(buffer)
 
1058
  buffer = []
1059
 
1060
+ if self.signal_input == "ADS":
1061
+ # Empty pipes
1062
+ while True:
1063
+ if p_data_i.poll():
1064
+ _ = p_data_i.recv()
1065
+ elif p_msg_io.poll():
1066
+ _ = p_msg_io.recv()
1067
+ else:
1068
+ break
1069
 
1070
+ p_data_i.close()
1071
+ p_msg_io.close()
1072
+ self._p_capture.join()
1073
+ self.__capture_on = False
1074
 
1075
  if record:
1076
+ recorder.close_recording_file()
 
 
 
1077
 
1078
 
1079
  if __name__ == "__main__":
portiloop/src/config.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEFAULT_FRONTEND_CONFIG = [
2
+ # nomenclature: name [default setting] [bits 7-0] : description
3
+ # Read only ID:
4
+ 0x3E, # ID [xx] [REV_ID[2:0], 1, DEV_ID[1:0], NU_CH[1:0]] : (RO)
5
+ # Global Settings Across Channels:
6
+ 0x96, # CONFIG1 [96] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 250 SPS
7
+ 0xC0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]] : No tests
8
+ 0x60, # CONFIG3 [60] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
9
+ 0x00, # LOFF [00] [COMP_TH[2:0], 0, ILEAD_OFF[1:0], FLEAD_OFF[1:0]] : No lead-off
10
+ # Channel-Specific Settings:
11
+ 0x61, # CH1SET [61] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] : Channel 1 active, 24 gain, no SRB2 & input shorted
12
+ 0x61, # CH2SET [61] [PD2, GAIN2[2:0], SRB2, MUX2[2:0]] : Channel 2 active, 24 gain, no SRB2 & input shorted
13
+ 0x61, # CH3SET [61] [PD3, GAIN3[2:0], SRB2, MUX3[2:0]] : Channel 3 active, 24 gain, no SRB2 & input shorted
14
+ 0x61, # CH4SET [61] [PD4, GAIN4[2:0], SRB2, MUX4[2:0]] : Channel 4 active, 24 gain, no SRB2 & input shorted
15
+ 0x61, # CH5SET [61] [PD5, GAIN5[2:0], SRB2, MUX5[2:0]] : Channel 5 active, 24 gain, no SRB2 & input shorted
16
+ 0x61, # CH6SET [61] [PD6, GAIN6[2:0], SRB2, MUX6[2:0]] : Channel 6 active, 24 gain, no SRB2 & input shorted
17
+ 0x61, # CH7SET [61] [PD7, GAIN7[2:0], SRB2, MUX7[2:0]] : Channel 7 active, 24 gain, no SRB2 & input shorted
18
+ 0x61, # CH8SET [61] [PD8, GAIN8[2:0], SRB2, MUX8[2:0]] : Channel 8 active, 24 gain, no SRB2 & input shorted
19
+ 0x00, # BIAS_SENSP [00] [BIASP8, BIASP7, BIASP6, BIASP5, BIASP4, BIASP3, BIASP2, BIASP1] : No bias
20
+ 0x00, # BIAS_SENSN [00] [BIASN8, BIASN7, BIASN6, BIASN5, BIASN4, BIASN3, BIASN2, BIASN1] No bias
21
+ 0x00, # LOFF_SENSP [00] [LOFFP8, LOFFP7, LOFFP6, LOFFP5, LOFFP4, LOFFP3, LOFFP2, LOFFP1] : No lead-off
22
+ 0x00, # LOFF_SENSN [00] [LOFFM8, LOFFM7, LOFFM6, LOFFM5, LOFFM4, LOFFM3, LOFFM2, LOFFM1] : No lead-off
23
+ 0x00, # LOFF_FLIP [00] [LOFF_FLIP8, LOFF_FLIP7, LOFF_FLIP6, LOFF_FLIP5, LOFF_FLIP4, LOFF_FLIP3, LOFF_FLIP2, LOFF_FLIP1] : No lead-off flip
24
+ # Lead-Off Status Registers (Read-Only Registers):
25
+ 0x00, # LOFF_STATP [00] [IN8P_OFF, IN7P_OFF, IN6P_OFF, IN5P_OFF, IN4P_OFF, IN3P_OFF, IN2P_OFF, IN1P_OFF] : Lead-off positive status (RO)
26
+ 0x00, # LOFF_STATN [00] [IN8M_OFF, IN7M_OFF, IN6M_OFF, IN5M_OFF, IN4M_OFF, IN3M_OFF, IN2M_OFF, IN1M_OFF] : Laed-off negative status (RO)
27
+ # GPIO and OTHER Registers:
28
+ 0x0F, # GPIO [0F] [GPIOD[4:1], GPIOC[4:1]] : All GPIOs as inputs
29
+ 0x00, # MISC1 [00] [0, 0, SRB1, 0, 0, 0, 0, 0] : Disable SRBM
30
+ 0x00, # MISC2 [00] [00] : Unused
31
+ 0x00, # CONFIG4 [00] [0, 0, 0, 0, SINGLE_SHOT, 0, PD_LOFF_COMP(bar), 0] : Single-shot, lead-off comparator disabled
32
+ ]
33
+
34
+ FRONTEND_CONFIG = [
35
+ 0x3E, # ID (RO)
36
+ 0x95, # CONFIG1 [95] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 500 SPS
37
+ 0xD0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]]
38
+ 0xFC, # CONFIG3 [E0] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
39
+ 0x00, # No lead-off
40
+ 0x62, # CH1SET [60] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] set to measure BIAS signal
41
+ 0x60, # CH2SET
42
+ 0x60, # CH3SET
43
+ 0x60, # CH4SET
44
+ 0x60, # CH5SET
45
+ 0x60, # CH6SET
46
+ 0x60, # CH7SET
47
+ 0x60, # CH8SET
48
+ 0x00, # BIAS_SENSP 00
49
+ 0x00, # BIAS_SENSN 00
50
+ 0x00, # LOFF_SENSP Lead-off on all positive pins?
51
+ 0x00, # LOFF_SENSN Lead-off on all negative pins?
52
+ 0x00, # Normal lead-off
53
+ 0x00, # Lead-off positive status (RO)
54
+ 0x00, # Lead-off negative status (RO)
55
+ 0x00, # All GPIOs as output ?
56
+ 0x20, # Enable SRB1
57
+ ]
58
+
59
+
60
+ LEADOFF_CONFIG = [
61
+ 0x3E, # ID (RO)
62
+ 0x95, # CONFIG1 [95] [1, DAISY_EN(bar), CLK_EN, 1, 0, DR[2:0]] : Datarate = 500 SPS
63
+ 0xC0, # CONFIG2 [C0] [1, 1, 0, INT_CAL, 0, CAL_AMP0, CAL_FREQ[1:0]]
64
+ 0xFC, # CONFIG3 [E0] [PD_REFBUF(bar), 1, 1, BIAS_MEAS, BIASREF_INT, PD_BIAS(bar), BIAS_LOFF_SENS, BIAS_STAT] : Power-down reference buffer, no bias
65
+ 0x00, # No lead-off
66
+ 0x60, # CH1SET [60] [PD1, GAIN1[2:0], SRB2, MUX1[2:0]] set to measure BIAS signal
67
+ 0x60, # CH2SET
68
+ 0x60, # CH3SET
69
+ 0x60, # CH4SET
70
+ 0x60, # CH5SET
71
+ 0x60, # CH6SET
72
+ 0x60, # CH7SET
73
+ 0x60, # CH8SET
74
+ 0x00, # BIAS_SENSP 00
75
+ 0x00, # BIAS_SENSN 00
76
+ 0xFF, # LOFF_SENSP Lead-off on all positive pins?
77
+ 0xFF, # LOFF_SENSN Lead-off on all negative pins?
78
+ 0x00, # Normal lead-off
79
+ 0x00, # Lead-off positive status (RO)
80
+ 0x00, # Lead-off negative status (RO)
81
+ 0x00, # All GPIOs as output ?
82
+ 0x20, # Enable SRB1
83
+ 0x00,
84
+ 0x02,
85
+ ]
86
+
87
+ def to_ads_frequency(frequency):
88
+ possible_datarates = [250, 500, 1000, 2000, 4000, 8000, 16000]
89
+ dr = 16000
90
+ for i in possible_datarates:
91
+ if i >= frequency:
92
+ dr = i
93
+ break
94
+ return dr
95
+
96
+ def mod_config(config, datarate, channel_modes):
97
+
98
+ # datarate:
99
+
100
+ possible_datarates = [(250, 0x06),
101
+ (500, 0x05),
102
+ (1000, 0x04),
103
+ (2000, 0x03),
104
+ (4000, 0x02),
105
+ (8000, 0x01),
106
+ (16000, 0x00)]
107
+ mod_dr = 0x00
108
+ for i, j in possible_datarates:
109
+ if i >= datarate:
110
+ mod_dr = j
111
+ break
112
+
113
+ new_cf1 = config[1] & 0xF8
114
+ new_cf1 = new_cf1 | mod_dr
115
+ config[1] = new_cf1
116
+
117
+ # bias:
118
+ assert len(channel_modes) == 7
119
+ config[13] = 0x00 # clear BIAS_SENSP
120
+ config[14] = 0x00 # clear BIAS_SENSN
121
+ for chan_i, chan_mode in enumerate(channel_modes):
122
+ n = 6 + chan_i
123
+ mod = config[n] & 0x78 # clear PDn and MUX[2:0]
124
+ if chan_mode == 'simple':
125
+ # If channel is activated, we send the channel's output to the BIAS mechanism
126
+ bit_i = 1 << chan_i + 1
127
+ config[13] = config[13] | bit_i
128
+ config[14] = config[14] | bit_i
129
+ elif chan_mode == 'disabled':
130
+ mod = mod | 0x81 # PDn = 1 and input shorted (001)
131
+ else:
132
+ assert False, f"Wrong key: {chan_mode}."
133
+ config[n] = mod
134
+ # for n, c in enumerate(config): # print ADS1299 configuration registers
135
+ # print(f"config[{n}]:\t{c:08b}\t({hex(c)})")
136
+ return config
portiloop/src/demo/demo.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from portiloop.src.demo.offline import run_offline
4
+
5
+
6
+ def on_upload_file(file):
7
+ # Check if file extension is .xdf
8
+ if file.name.split(".")[-1] != "xdf":
9
+ raise gr.Error("Please upload a .xdf file.")
10
+ else:
11
+ return file.name
12
+
13
+
14
+ def main():
15
+ with gr.Blocks(title="Portiloop") as demo:
16
+ gr.Markdown("# Portiloop Demo")
17
+ gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
18
+ gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
19
+
20
+ with gr.Row():
21
+ xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
22
+ xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
23
+
24
+ xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
25
+
26
+ # Make a checkbox group for the options
27
+ detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
28
+
29
+ # Threshold value
30
+ threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
31
+ # Detection Channel
32
+ detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
33
+ # Frequency
34
+ freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
35
+
36
+ output_array = gr.File(label="Output CSV File")
37
+
38
+ run_inference = gr.Button(value="Run Inference")
39
+ run_inference.click(
40
+ fn=run_offline,
41
+ inputs=[
42
+ xdf_file_static,
43
+ detect_filter,
44
+ threshold,
45
+ detect_channel,
46
+ freq],
47
+ outputs=[output_array])
48
+
49
+ demo.queue()
50
+ demo.launch(share=True)
51
+
52
+ if __name__ == "__main__":
53
+ main()
portiloop/src/demo/offline.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from portiloop.src.detection import SleepSpindleRealTimeDetector
4
+ plt.switch_backend('agg')
5
+ from portiloop.src.processing import FilterPipeline
6
+ from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
7
+ import gradio as gr
8
+
9
+
10
+ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
11
+ # Get the options from the checkbox group
12
+ offline_filtering = 0 in detect_filter_opts
13
+ lacourse = 1 in detect_filter_opts
14
+ wamsley = 2 in detect_filter_opts
15
+ online_filtering = 3 in detect_filter_opts
16
+ online_detection = 4 in detect_filter_opts
17
+
18
+ # Make sure the inputs make sense:
19
+ if not offline_filtering and (lacourse or wamsley):
20
+ raise gr.Error("You can't use the offline detection methods without offline filtering.")
21
+
22
+ if not online_filtering and online_detection:
23
+ raise gr.Error("You can't use the online detection without online filtering.")
24
+
25
+ if xdf_file is None:
26
+ raise gr.Error("Please upload a .xdf file.")
27
+
28
+ freq = int(freq)
29
+
30
+ # Read the xdf file to a numpy array
31
+ print("Loading xdf file...")
32
+ data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
33
+ # Do the offline filtering of the data
34
+ if offline_filtering:
35
+ print("Filtering offline...")
36
+ offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
37
+ # Expand the dimension of the filtered data to match the shape of the other columns
38
+ offline_filtered_data = np.expand_dims(offline_filtered_data, axis=1)
39
+ data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
40
+ columns.append("offline_filtered_signal")
41
+
42
+ # Do Wamsley's method
43
+ if wamsley:
44
+ print("Running Wamsley detection...")
45
+ wamsley_data = offline_detect("Wamsley", \
46
+ data_whole[:, columns.index("offline_filtered_signal")],\
47
+ data_whole[:, columns.index("time_stamps")],\
48
+ freq)
49
+ wamsley_data = np.expand_dims(wamsley_data, axis=1)
50
+ data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
51
+ columns.append("wamsley_spindles")
52
+
53
+ # Do Lacourse's method
54
+ if lacourse:
55
+ print("Running Lacourse detection...")
56
+ lacourse_data = offline_detect("Lacourse", \
57
+ data_whole[:, columns.index("offline_filtered_signal")],\
58
+ data_whole[:, columns.index("time_stamps")],\
59
+ freq)
60
+ lacourse_data = np.expand_dims(lacourse_data, axis=1)
61
+ data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
62
+ columns.append("lacourse_spindles")
63
+
64
+ # Get the data from the raw signal column
65
+ data = data_whole[:, columns.index("raw_signal")]
66
+
67
+ # Create the online filtering pipeline
68
+ if online_filtering:
69
+ filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
70
+
71
+ # Create the detector
72
+ if online_detection:
73
+ detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
74
+ stimulator = OfflineSleepSpindleRealTimeStimulator()
75
+
76
+ if online_filtering or online_detection:
77
+ print("Running online filtering and detection...")
78
+
79
+ points = []
80
+ online_activations = []
81
+
82
+ # Go through the data
83
+ for index, point in enumerate(data):
84
+ # Filter the data
85
+ if online_filtering:
86
+ filtered_point = filter.filter(np.array([point]))
87
+ else:
88
+ filtered_point = point
89
+ filtered_point = filtered_point.tolist()
90
+ points.append(filtered_point[0])
91
+
92
+ if online_detection:
93
+ # Detect the spindles
94
+ result = detector.detect([filtered_point])
95
+
96
+ # Stimulate if necessary
97
+ stim = stimulator.stimulate(result)
98
+ if stim:
99
+ online_activations.append(1)
100
+ else:
101
+ online_activations.append(0)
102
+
103
+ if online_filtering:
104
+ online_filtered = np.array(points)
105
+ online_filtered = np.expand_dims(online_filtered, axis=1)
106
+ data_whole = np.concatenate((data_whole, online_filtered), axis=1)
107
+ columns.append("online_filtered_signal")
108
+
109
+ if online_detection:
110
+ online_activations = np.array(online_activations)
111
+ online_activations = np.expand_dims(online_activations, axis=1)
112
+ data_whole = np.concatenate((data_whole, online_activations), axis=1)
113
+ columns.append("online_stimulations")
114
+
115
+ print("Saving output...")
116
+ # Output the data to a csv file
117
+ np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
118
+
119
+ print("Done!")
120
+ return "output.csv"
portiloop/src/demo/test_offline.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import unittest
3
+ from portiloop.src.demo.offline import run_offline
4
+ from pathlib import Path
5
+
6
+
7
+ class TestOffline(unittest.TestCase):
8
+
9
+ def setUp(self):
10
+ combinatorial_config = {
11
+ 'offline_filtering': [True, False],
12
+ 'online_filtering': [True, False],
13
+ 'online_detection': [True, False],
14
+ 'wamsley': [True, False],
15
+ 'lacourse': [True, False],
16
+ }
17
+
18
+ self.exclusives = [("duplicate_as_window", "use_cnn_encoder")]
19
+
20
+ keys = list(combinatorial_config)
21
+ all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
22
+ all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
23
+ self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
24
+ self.xdf_file = Path(__file__).parents[3] / "test_xdf" / "test_file.xdf"
25
+
26
+
27
+ def test_all_options(self):
28
+ for config in self.filtered_options:
29
+ if config['online_detection']:
30
+ self.assertTrue(config['online_filtering'])
31
+
32
+ def test_single_option(self):
33
+ res = list(run_offline(
34
+ self.xdf_file,
35
+ offline_filtering=True,
36
+ online_filtering=True,
37
+ online_detection=True,
38
+ wamsley=True,
39
+ lacourse=True,
40
+ threshold=0.5,
41
+ channel_num=2,
42
+ freq=250))
43
+ print(res)
44
+
45
+ def tearDown(self):
46
+ pass
47
+
48
+ if __name__ == '__main__':
49
+ unittest.main()
portiloop/src/demo/utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pyxdf
3
+ from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012
4
+ from scipy.signal import butter, filtfilt, iirnotch, detrend
5
+ import time
6
+ from portiloop.src.stimulation import Stimulator
7
+
8
+
9
+ STREAM_NAMES = {
10
+ 'filtered_data': 'Portiloop Filtered',
11
+ 'raw_data': 'Portiloop Raw Data',
12
+ 'stimuli': 'Portiloop_stimuli'
13
+ }
14
+
15
+
16
+ class OfflineSleepSpindleRealTimeStimulator(Stimulator):
17
+ def __init__(self):
18
+ self.last_detected_ts = time.time()
19
+ self.wait_t = 0.4 # 400 ms
20
+ self.delayer = None
21
+
22
+ def stimulate(self, detection_signal):
23
+ stim = False
24
+ for sig in detection_signal:
25
+ # We detect a stimulation
26
+ if sig:
27
+ # Record time of stimulation
28
+ ts = time.time()
29
+
30
+ # Check if time since last stimulation is long enough
31
+ if ts - self.last_detected_ts > self.wait_t:
32
+ if self.delayer is not None:
33
+ # If we have a delayer, notify it
34
+ self.delayer.detected()
35
+ stim = True
36
+
37
+ self.last_detected_ts = ts
38
+ return stim
39
+
40
+ def add_delayer(self, delayer):
41
+ self.delayer = delayer
42
+ self.delayer.stimulate = lambda: True
43
+
44
+ def xdf2array(xdf_path, channel):
45
+ xdf_data, _ = pyxdf.load_xdf(xdf_path)
46
+
47
+ # Load all streams given their names
48
+ filtered_stream, raw_stream, markers = None, None, None
49
+ for stream in xdf_data:
50
+ # print(stream['info']['name'])
51
+ if stream['info']['name'][0] == STREAM_NAMES['filtered_data']:
52
+ filtered_stream = stream
53
+ elif stream['info']['name'][0] == STREAM_NAMES['raw_data']:
54
+ raw_stream = stream
55
+ elif stream['info']['name'][0] == STREAM_NAMES['stimuli']:
56
+ markers = stream
57
+
58
+ if filtered_stream is None or raw_stream is None:
59
+ raise ValueError("One of the necessary streams could not be found. Make sure that at least one signal stream is present in XDF recording")
60
+
61
+ # Add all samples from raw and filtered signals
62
+ csv_list = []
63
+ diffs = []
64
+ shortest_stream = min(int(filtered_stream['footer']['info']['sample_count'][0]),
65
+ int(raw_stream['footer']['info']['sample_count'][0]))
66
+ for i in range(shortest_stream):
67
+ if markers is not None:
68
+ datapoint = [filtered_stream['time_stamps'][i],
69
+ float(filtered_stream['time_series'][i, channel-1]),
70
+ raw_stream['time_series'][i, channel-1],
71
+ 0]
72
+ else:
73
+ datapoint = [filtered_stream['time_stamps'][i],
74
+ float(filtered_stream['time_series'][i, channel-1]),
75
+ raw_stream['time_series'][i, channel-1]]
76
+ diffs.append(abs(filtered_stream['time_stamps'][i] - raw_stream['time_stamps'][i]))
77
+ csv_list.append(datapoint)
78
+
79
+ # Add markers
80
+ columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
81
+ if markers is not None:
82
+ columns.append("online_stimulations_portiloop")
83
+ for time_stamp in markers['time_stamps']:
84
+ new_index = np.abs(filtered_stream['time_stamps'] - time_stamp).argmin()
85
+ csv_list[new_index][3] = 1
86
+
87
+ return np.array(csv_list), columns
88
+
89
+
90
+ def offline_detect(method, data, timesteps, freq):
91
+ # Get the spindle data from the offline methods
92
+ time = np.arange(0, len(data)) / freq
93
+ if method == "Lacourse":
94
+ detector = DetectSpindle(method='Lacourse2018')
95
+ spindles, _, _ = detect_Lacourse2018(data, freq, time, detector)
96
+ elif method == "Wamsley":
97
+ detector = DetectSpindle(method='Wamsley2012')
98
+ spindles, _, _ = detect_Wamsley2012(data, freq, time, detector)
99
+ else:
100
+ raise ValueError("Invalid method")
101
+
102
+ # Convert the spindle data to a numpy array
103
+ spindle_result = np.zeros(data.shape)
104
+ for spindle in spindles:
105
+ start = spindle["start"]
106
+ end = spindle["end"]
107
+ # Find index of timestep closest to start and end
108
+ start_index = np.argmin(np.abs(timesteps - start))
109
+ end_index = np.argmin(np.abs(timesteps - end))
110
+ spindle_result[start_index:end_index] = 1
111
+ return spindle_result
112
+
113
+
114
+ def offline_filter(signal, freq):
115
+
116
+ # Notch filter
117
+ f0 = 60.0 # Frequency to be removed from signal (Hz)
118
+ Q = 100.0 # Quality factor
119
+ b, a = iirnotch(f0, Q, freq)
120
+ signal = filtfilt(b, a, signal)
121
+
122
+ # Bandpass filter
123
+ lowcut = 0.5
124
+ highcut = 40.0
125
+ order = 4
126
+ b, a = butter(order, [lowcut / (freq / 2.0), highcut / (freq / 2.0)], btype='bandpass')
127
+ signal = filtfilt(b, a, signal)
128
+
129
+ # Detrend the signal
130
+ signal = detrend(signal)
131
+
132
+ return signal
portiloop/{detection.py β†’ src/detection.py} RENAMED
@@ -1,8 +1,12 @@
1
  from abc import ABC, abstractmethod
2
  import time
3
  from pathlib import Path
 
4
 
5
- from pycoral.utils import edgetpu
 
 
 
6
  import numpy as np
7
 
8
 
@@ -39,7 +43,7 @@ class Detector(ABC):
39
 
40
  # Example implementation for sleep spindles:
41
 
42
- DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
43
  # print(DEFAULT_MODEL_PATH)
44
 
45
  class SleepSpindleRealTimeDetector(Detector):
@@ -51,7 +55,10 @@ class SleepSpindleRealTimeDetector(Detector):
51
 
52
  self.interpreters = []
53
  for i in range(self.num_models_parallel):
54
- self.interpreters.append(edgetpu.make_interpreter(model_path))
 
 
 
55
  self.interpreters[i].allocate_tensors()
56
  self.interpreter_counter = 0
57
 
@@ -74,6 +81,10 @@ class SleepSpindleRealTimeDetector(Detector):
74
  super().__init__(threshold)
75
 
76
  def detect(self, datapoints):
 
 
 
 
77
  res = []
78
  for inp in datapoints:
79
  result = self.add_datapoint(inp)
@@ -143,7 +154,3 @@ class SleepSpindleRealTimeDetector(Detector):
143
  print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
144
 
145
  return output_data_y, output_data_h
146
-
147
-
148
-
149
-
 
1
  from abc import ABC, abstractmethod
2
  import time
3
  from pathlib import Path
4
+ from portiloop.src import ADS
5
 
6
+ if ADS:
7
+ from pycoral.utils import edgetpu
8
+ else:
9
+ import tensorflow as tf
10
  import numpy as np
11
 
12
 
 
43
 
44
  # Example implementation for sleep spindles:
45
 
46
+ DEFAULT_MODEL_PATH = str(Path(__file__).parent.parent / "models/portiloop_model_quant.tflite")
47
  # print(DEFAULT_MODEL_PATH)
48
 
49
  class SleepSpindleRealTimeDetector(Detector):
 
55
 
56
  self.interpreters = []
57
  for i in range(self.num_models_parallel):
58
+ if ADS:
59
+ self.interpreters.append(edgetpu.make_interpreter(model_path))
60
+ else:
61
+ self.interpreters.append(tf.lite.Interpreter(model_path=model_path))
62
  self.interpreters[i].allocate_tensors()
63
  self.interpreter_counter = 0
64
 
 
81
  super().__init__(threshold)
82
 
83
  def detect(self, datapoints):
84
+ """
85
+ Takes datapoints as input and outputs a detection signal.
86
+ datapoints is a list of lists of n channels: may contain several datapoints.
87
+ """
88
  res = []
89
  for inp in datapoints:
90
  result = self.add_datapoint(inp)
 
154
  print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
155
 
156
  return output_data_y, output_data_h
 
 
 
 
portiloop/{hardware β†’ src/hardware}/__init__.py RENAMED
File without changes
portiloop/{demo β†’ src/hardware/demo}/acquisition_demo.py RENAMED
File without changes
portiloop/{nn β†’ src/hardware/demo}/demo_net.py RENAMED
File without changes
portiloop/{demo β†’ src/hardware/demo}/led_demo.py RENAMED
File without changes
portiloop/{hardware β†’ src/hardware}/frontend.py RENAMED
File without changes
portiloop/{hardware β†’ src/hardware}/leds.py RENAMED
File without changes
portiloop/src/processing.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.signal import firwin
3
+
4
+
5
+ def filter_24(value):
6
+ return (value * 4.5) / (2**23 - 1) / 24.0 * 1e6 # 23 because 1 bit is lost for sign
7
+
8
+ def filter_2scomplement_np(value):
9
+ return np.where((value & (1 << 23)) != 0, value - (1 << 24), value)
10
+
11
+
12
+ def int_to_float(value):
13
+ """
14
+ Convert the int value out of the ADS into a value in microvolts
15
+ """
16
+ return filter_24(filter_2scomplement_np(value))
17
+
18
+
19
+ def shift_numpy(arr, num, fill_value=np.nan):
20
+ result = np.empty_like(arr)
21
+ if num > 0:
22
+ result[:num] = fill_value
23
+ result[num:] = arr[:-num]
24
+ elif num < 0:
25
+ result[num:] = fill_value
26
+ result[:num] = arr[-num:]
27
+ else:
28
+ result[:] = arr
29
+ return result
30
+
31
+
32
+ class FIR:
33
+ def __init__(self, nb_channels, coefficients, buffer=None):
34
+
35
+ self.coefficients = np.expand_dims(np.array(coefficients), axis=1)
36
+ self.taps = len(self.coefficients)
37
+ self.nb_channels = nb_channels
38
+ self.buffer = np.array(buffer) if buffer is not None else np.zeros((self.taps, self.nb_channels))
39
+
40
+ def filter(self, x):
41
+ self.buffer = shift_numpy(self.buffer, 1, x)
42
+ filtered = np.sum(self.buffer * self.coefficients, axis=0)
43
+ return filtered
44
+
45
+
46
+ class FilterPipeline:
47
+ def __init__(self,
48
+ nb_channels,
49
+ sampling_rate,
50
+ power_line_fq=60,
51
+ use_custom_fir=False,
52
+ custom_fir_order=20,
53
+ custom_fir_cutoff=30,
54
+ alpha_avg=0.1,
55
+ alpha_std=0.001,
56
+ epsilon=0.000001,
57
+ filter_args=[]):
58
+ if len(filter_args) > 0:
59
+ use_fir, use_notch, use_std = filter_args
60
+ else:
61
+ use_fir=True,
62
+ use_notch=True,
63
+ use_std=True
64
+ self.use_fir = use_fir
65
+ self.use_notch = use_notch
66
+ self.use_std = use_std
67
+ self.nb_channels = nb_channels
68
+ assert power_line_fq in [50, 60], f"The only supported power line frequencies are 50 Hz and 60 Hz"
69
+ if power_line_fq == 60:
70
+ self.notch_coeff1 = -0.12478308884588535
71
+ self.notch_coeff2 = 0.98729186796473023
72
+ self.notch_coeff3 = 0.99364593398236511
73
+ self.notch_coeff4 = -0.12478308884588535
74
+ self.notch_coeff5 = 0.99364593398236511
75
+ else:
76
+ self.notch_coeff1 = -0.61410695998423581
77
+ self.notch_coeff2 = 0.98729186796473023
78
+ self.notch_coeff3 = 0.99364593398236511
79
+ self.notch_coeff4 = -0.61410695998423581
80
+ self.notch_coeff5 = 0.99364593398236511
81
+ self.dfs = [np.zeros(self.nb_channels), np.zeros(self.nb_channels)]
82
+
83
+ self.moving_average = None
84
+ self.moving_variance = np.zeros(self.nb_channels)
85
+ self.ALPHA_AVG = alpha_avg
86
+ self.ALPHA_STD = alpha_std
87
+ self.EPSILON = epsilon
88
+
89
+ if use_custom_fir:
90
+ self.fir_coef = firwin(numtaps=custom_fir_order+1, cutoff=custom_fir_cutoff, fs=sampling_rate)
91
+ else:
92
+ self.fir_coef = [
93
+ 0.001623780150148094927192721215192250384,
94
+ 0.014988684599373741992978104065059596905,
95
+ 0.021287595318265635502275046064823982306,
96
+ 0.007349500393709578957568417933998716762,
97
+ -0.025127515717112181709014251396183681209,
98
+ -0.052210507359822452833064687638398027048,
99
+ -0.039273839505489904766477593511808663607,
100
+ 0.033021568427940004020193498490698402748,
101
+ 0.147606943281569008563636202779889572412,
102
+ 0.254000252034505602516389899392379447818,
103
+ 0.297330876398883392486283128164359368384,
104
+ 0.254000252034505602516389899392379447818,
105
+ 0.147606943281569008563636202779889572412,
106
+ 0.033021568427940004020193498490698402748,
107
+ -0.039273839505489904766477593511808663607,
108
+ -0.052210507359822452833064687638398027048,
109
+ -0.025127515717112181709014251396183681209,
110
+ 0.007349500393709578957568417933998716762,
111
+ 0.021287595318265635502275046064823982306,
112
+ 0.014988684599373741992978104065059596905,
113
+ 0.001623780150148094927192721215192250384]
114
+ self.fir = FIR(self.nb_channels, self.fir_coef)
115
+
116
+ def filter(self, value):
117
+ """
118
+ value: a numpy array of shape (data series, channels)
119
+ """
120
+ for i, x in enumerate(value): # loop over the data series
121
+ # FIR:
122
+ if self.use_fir:
123
+ x = self.fir.filter(x)
124
+ # notch:
125
+ if self.use_notch:
126
+ denAccum = (x - self.notch_coeff1 * self.dfs[0]) - self.notch_coeff2 * self.dfs[1]
127
+ x = (self.notch_coeff3 * denAccum + self.notch_coeff4 * self.dfs[0]) + self.notch_coeff5 * self.dfs[1]
128
+ self.dfs[1] = self.dfs[0]
129
+ self.dfs[0] = denAccum
130
+ # standardization:
131
+ if self.use_std:
132
+ if self.moving_average is not None:
133
+ delta = x - self.moving_average
134
+ self.moving_average = self.moving_average + self.ALPHA_AVG * delta
135
+ self.moving_variance = (1 - self.ALPHA_STD) * (self.moving_variance + self.ALPHA_STD * delta**2)
136
+ moving_std = np.sqrt(self.moving_variance)
137
+ x = (x - self.moving_average) / (moving_std + self.EPSILON)
138
+ else:
139
+ self.moving_average = x
140
+ value[i] = x
141
+ return value
portiloop/{stimulation.py β†’ src/stimulation.py} RENAMED
@@ -1,10 +1,17 @@
1
  from abc import ABC, abstractmethod
 
2
  import time
3
  from threading import Thread, Lock
4
  from pathlib import Path
5
- import alsaaudio
 
 
 
 
 
 
6
  import wave
7
- import pylsl
8
 
9
 
10
  # Abstract interface for developers:
@@ -32,13 +39,11 @@ class Stimulator(ABC):
32
 
33
  class SleepSpindleRealTimeStimulator(Stimulator):
34
  def __init__(self):
35
- self._sound = Path(__file__).parent / 'sounds' / 'stimulus.wav'
36
  print(f"DEBUG:{self._sound}")
37
  self._thread = None
38
  self._lock = Lock()
39
  self.last_detected_ts = time.time()
40
- self.wait_counter = 0
41
- self.delayed = False
42
  self.wait_t = 0.4 # 400 ms
43
 
44
  lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
@@ -55,8 +60,6 @@ class SleepSpindleRealTimeStimulator(Stimulator):
55
 
56
  self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
57
  self.lsl_outlet_markers_fast = pylsl.StreamOutlet(lsl_markers_info_fast)
58
-
59
- self.delayer = None
60
 
61
  # Initialize Alsa stuff
62
  # Open WAV file and set PCM device
@@ -88,7 +91,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
88
  while data:
89
  self.wav_list.append(data)
90
  data = f.readframes(self.periodsize)
91
-
92
  def play_sound(self):
93
  '''
94
  Open the wav file and play a sound
@@ -98,39 +101,35 @@ class SleepSpindleRealTimeStimulator(Stimulator):
98
 
99
  def stimulate(self, detection_signal):
100
  for sig in detection_signal:
101
- # We are waiting for a delayed stimulation
102
- if self.delayed:
103
- if self.wait_counter >= self.wait_time:
104
- with self._lock:
105
- if self._thread is None:
106
- self._thread = Thread(target=self._t_sound, daemon=True)
107
- self._thread.start()
108
- self.delayed = False
109
- else:
110
- self.wait_counter += 1
111
  # We detect a stimulation
112
- elif sig:
113
  # Record time of stimulation
114
- self.lsl_outlet_markers_fast.push_sample(['FASTSTIM'])
115
  ts = time.time()
116
 
117
- # Prompt delayer to try and get a stimulation
118
- if self.delayer is not None:
119
- self.wait_time = self.delayer.stimulate()
120
- self.delayed = True
121
- self.wait_counter = 0
122
- continue
123
-
124
- # Stimulate if allowed
125
  if ts - self.last_detected_ts > self.wait_t:
126
- with self._lock:
127
- if self._thread is None:
128
- self._thread = Thread(target=self._t_sound, daemon=True)
129
- self._thread.start()
 
 
 
 
130
  self.last_detected_ts = ts
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def _t_sound(self):
133
- self.lsl_outlet_markers.push_sample(['STIM'])
134
  self.play_sound()
135
  with self._lock:
136
  self._thread = None
@@ -140,6 +139,94 @@ class SleepSpindleRealTimeStimulator(Stimulator):
140
  if self._thread is None:
141
  self._thread = Thread(target=self._t_sound, daemon=True)
142
  self._thread.start()
143
-
144
  def add_delayer(self, delayer):
145
  self.delayer = delayer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
+ from enum import Enum
3
  import time
4
  from threading import Thread, Lock
5
  from pathlib import Path
6
+
7
+ from portiloop.src import ADS
8
+
9
+ if ADS:
10
+ import alsaaudio
11
+ import pylsl
12
+
13
  import wave
14
+ from scipy.signal import find_peaks
15
 
16
 
17
  # Abstract interface for developers:
 
39
 
40
  class SleepSpindleRealTimeStimulator(Stimulator):
41
  def __init__(self):
42
+ self._sound = Path(__file__).parent.parent / 'sounds' / 'stimulus.wav'
43
  print(f"DEBUG:{self._sound}")
44
  self._thread = None
45
  self._lock = Lock()
46
  self.last_detected_ts = time.time()
 
 
47
  self.wait_t = 0.4 # 400 ms
48
 
49
  lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
 
60
 
61
  self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
62
  self.lsl_outlet_markers_fast = pylsl.StreamOutlet(lsl_markers_info_fast)
 
 
63
 
64
  # Initialize Alsa stuff
65
  # Open WAV file and set PCM device
 
91
  while data:
92
  self.wav_list.append(data)
93
  data = f.readframes(self.periodsize)
94
+
95
  def play_sound(self):
96
  '''
97
  Open the wav file and play a sound
 
101
 
102
  def stimulate(self, detection_signal):
103
  for sig in detection_signal:
 
 
 
 
 
 
 
 
 
 
104
  # We detect a stimulation
105
+ if sig:
106
  # Record time of stimulation
 
107
  ts = time.time()
108
 
109
+ # Check if time since last stimulation is long enough
 
 
 
 
 
 
 
110
  if ts - self.last_detected_ts > self.wait_t:
111
+ if self.delayer is not None:
112
+ # If we have a delayer, notify it
113
+ self.delayer.detected()
114
+ # Send the LSL marer for the fast stimulation
115
+ self.send_stimulation("FAST_STIM", False)
116
+ else:
117
+ self.send_stimulation("STIM", True)
118
+
119
  self.last_detected_ts = ts
120
+
121
+ def send_stimulation(self, lsl_text, sound):
122
+ # Send lsl stimulation
123
+ self.lsl_outlet_markers.push_sample([lsl_text])
124
+ # Send sound to patient
125
+ if sound:
126
+ with self._lock:
127
+ if self._thread is None:
128
+ self._thread = Thread(target=self._t_sound, daemon=True)
129
+ self._thread.start()
130
+
131
 
132
  def _t_sound(self):
 
133
  self.play_sound()
134
  with self._lock:
135
  self._thread = None
 
139
  if self._thread is None:
140
  self._thread = Thread(target=self._t_sound, daemon=True)
141
  self._thread.start()
142
+
143
  def add_delayer(self, delayer):
144
  self.delayer = delayer
145
+ self.delayer.stimulate = lambda: self.send_stimulation("DELAY_STIM", True)
146
+
147
+ # Class that delays stimulation to always stimulate peak or through
148
+ class UpStateDelayer:
149
+ def __init__(self, sample_freq, spindle_freq, peak, time_to_buffer):
150
+ '''
151
+ args:
152
+ sample_freq: int -> Sampling frequency of signal in Hz
153
+ time_to_wait: float -> Time to wait to build buffer in seconds
154
+ '''
155
+ # Get number of timesteps for a whole spindle
156
+ self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
157
+ self.sample_freq = sample_freq
158
+ self.buffer_size = 1.5 * self.spindle_timesteps
159
+ self.peak = peak
160
+ self.buffer = []
161
+ self.time_to_buffer = time_to_buffer
162
+ self.stimulate = None
163
+
164
+ self.state = States.NO_SPINDLE
165
+
166
+ def step(self, point):
167
+ '''
168
+ Step the delayer, ads a point to buffer if necessary.
169
+ Returns True if stimulation is actually done
170
+ '''
171
+ if self.state == States.NO_SPINDLE:
172
+ return False
173
+ elif self.state == States.BUFFERING:
174
+ self.buffer.append(point)
175
+ # If we are done buffering, move on to the waiting stage
176
+ if time.time() - self.time_started >= self.time_to_buffer:
177
+ # Compute the necessary time to wait
178
+ self.time_to_wait = self.compute_time_to_wait()
179
+ self.state = States.DELAYING
180
+ self.buffer = []
181
+ self.time_started = time.time()
182
+ return False
183
+ elif self.state == States.DELAYING:
184
+ # Check if we are done delaying
185
+ if time.time() - self.time_started >= self.time_to_wait:
186
+ # Actually stimulate the patient after the delay
187
+ if self.stimulate is not None:
188
+ self.stimulate()
189
+ # Reset state
190
+ self.time_to_wait = -1
191
+ self.state = States.NO_SPINDLE
192
+ return True
193
+ return False
194
+
195
+ def detected(self):
196
+ if self.state == States.NO_SPINDLE:
197
+ self.state = States.BUFFERING
198
+ self.time_started = time.time()
199
+
200
+ def compute_time_to_wait(self):
201
+ """
202
+ Computes the time we want to wait in total based on the spindle frequency and the buffer
203
+ """
204
+ # If we want to look at the valleys, we search for peaks on the inversed signal
205
+ if not self.peak:
206
+ self.buffer = -self.buffer
207
+
208
+ # Returns the index of the last peak in the buffer
209
+ peaks, _ = find_peaks(self.buffer, prominence=1)
210
+
211
+ # Compute the time until next peak and return it
212
+ return (len(self.buffer) - peaks[-1]) * (1 / self.sample_freq)
213
+
214
+ class States(Enum):
215
+ NO_SPINDLE = 0
216
+ BUFFERING = 1
217
+ DELAYING = 2
218
+
219
+
220
+ if __name__ == "__main__":
221
+ import numpy as np
222
+ import matplotlib.pyplot as plt
223
+
224
+ freq = 250
225
+ spindle_freq = 10
226
+ time = 10
227
+ x = np.linspace(0, time * np.pi, num=time*freq)
228
+ n = np.random.normal(scale=1, size=x.size)
229
+ y = np.sin(x) + n
230
+ plt.plot(x, y)
231
+ plt.show()
232
+
portiloop/src/utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from EDFlib.edfwriter import EDFwriter
2
+ from portilooplot.jupyter_plot import ProgressPlot
3
+ from pathlib import Path
4
+ import numpy as np
5
+
6
+ EDF_PATH = Path.home() / 'workspace' / 'edf_recording'
7
+
8
+
9
+
10
+ class DummyAlsaMixer:
11
+ def __init__(self):
12
+ self.volume = 50
13
+
14
+ def getvolume(self):
15
+ return [self.volume]
16
+
17
+ def setvolume(self, volume):
18
+ self.volume = volume
19
+
20
+
21
+ class EDFRecorder:
22
+ def __init__(self, signal_labels):
23
+ self.filename = EDF_PATH / 'recording.edf'
24
+ self.nb_signals = 8
25
+ self.samples_per_datarecord_array = self.frequency
26
+ self.physical_max = 5
27
+ self.physical_min = -5
28
+ self.signal_labels = signal_labels
29
+
30
+ def open_recording_file(self):
31
+ nb_signals = self.nb_signals
32
+ samples_per_datarecord_array = self.samples_per_datarecord_array
33
+ physical_max = self.physical_max
34
+ physical_min = self.physical_min
35
+ signal_labels = self.signal_labels
36
+
37
+ print(f"Will store edf recording in {self.filename}")
38
+
39
+ self.edf_writer = EDFwriter(p_path=str(self.filename),
40
+ f_file_type=EDFwriter.EDFLIB_FILETYPE_EDFPLUS,
41
+ number_of_signals=nb_signals)
42
+
43
+ for signal in range(nb_signals):
44
+ assert self.edf_writer.setSampleFrequency(signal, samples_per_datarecord_array) == 0
45
+ assert self.edf_writer.setPhysicalMaximum(signal, physical_max) == 0
46
+ assert self.edf_writer.setPhysicalMinimum(signal, physical_min) == 0
47
+ assert self.edf_writer.setDigitalMaximum(signal, 32767) == 0
48
+ assert self.edf_writer.setDigitalMinimum(signal, -32768) == 0
49
+ assert self.edf_writer.setSignalLabel(signal, signal_labels[signal]) == 0
50
+ assert self.edf_writer.setPhysicalDimension(signal, 'V') == 0
51
+
52
+ def close_recording_file(self):
53
+ assert self.edf_writer.close() == 0
54
+
55
+ def add_recording_data(self, data):
56
+ self.edf_buffer += data
57
+ if len(self.edf_buffer) >= self.samples_per_datarecord_array:
58
+ datarecord_array = self.edf_buffer[:self.samples_per_datarecord_array]
59
+ self.edf_buffer = self.edf_buffer[self.samples_per_datarecord_array:]
60
+ datarecord_array = np.array(datarecord_array).transpose()
61
+ assert len(datarecord_array) == self.nb_signals, f"len(data)={len(data)}!={self.nb_signals}"
62
+ for d in datarecord_array:
63
+ assert len(d) == self.samples_per_datarecord_array, f"{len(d)}!={self.samples_per_datarecord_array}"
64
+ assert self.edf_writer.writeSamples(d) == 0
65
+
66
+
67
+ class LiveDisplay():
68
+ def __init__(self, channel_names, window_len=100):
69
+ self.datapoint_dim = len(channel_names)
70
+ self.history = []
71
+ self.pp = ProgressPlot(plot_names=channel_names, max_window_len=window_len)
72
+ self.matplotlib = False
73
+
74
+ def add_datapoints(self, datapoints):
75
+ """
76
+ Adds 8 lists of datapoints to the plot
77
+
78
+ Args:
79
+ datapoints: list of 8 lists of floats (or list of 8 floats)
80
+ """
81
+ if self.matplotlib:
82
+ import matplotlib.pyplot as plt
83
+ disp_list = []
84
+ for datapoint in datapoints:
85
+ d = [[elt] for elt in datapoint]
86
+ disp_list.append(d)
87
+
88
+ if self.matplotlib:
89
+ self.history += d[1]
90
+
91
+ if not self.matplotlib:
92
+ self.pp.update_with_datapoints(disp_list)
93
+ elif len(self.history) == 1000:
94
+ plt.plot(self.history)
95
+ plt.show()
96
+ self.history = []
97
+
98
+ def add_datapoint(self, datapoint):
99
+ disp_list = [[elt] for elt in datapoint]
100
+ self.pp.update(disp_list)
101
+
102
+
103
+ class FileReader:
104
+ def __init__(self, filename):
105
+ raise NotImplementedError
106
+
107
+ def get_point(self):
108
+ raise NotImplementedError
setup.py CHANGED
@@ -11,10 +11,17 @@ setup(
11
  'portilooplot',
12
  'ipywidgets',
13
  'python-periphery',
14
- 'spidev',
15
- 'pylsl-coral',
16
  'scipy',
17
- 'pycoral',
18
- 'pyalsaaudio'
19
- ]
 
 
 
 
 
 
 
 
 
20
  )
 
11
  'portilooplot',
12
  'ipywidgets',
13
  'python-periphery',
 
 
14
  'scipy',
15
+ 'matplotlib',
16
+ ],
17
+ extras_require={
18
+ 'Portiloop': ['pycoral',
19
+ 'spidev',
20
+ 'pylsl-coral',
21
+ 'pyalsaaudio'],
22
+ 'PC': ['gradio',
23
+ 'tensorflow',
24
+ 'pyxdf',
25
+ 'wonambi']
26
+ },
27
  )