Spaces:
Sleeping
Sleeping
Yann Bouteiller
commited on
Commit
·
6da664d
1
Parent(s):
0671861
Added spindle stimulation mode
Browse files- portiloop/capture.py +98 -2
- portiloop/stimulation.py +29 -2
portiloop/capture.py
CHANGED
@@ -429,6 +429,58 @@ class DummyAlsaMixer:
|
|
429 |
self.volume = volume
|
430 |
|
431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
class Capture:
|
433 |
def __init__(self, detector_cls=None, stimulator_cls=None):
|
434 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
@@ -465,6 +517,8 @@ class Capture:
|
|
465 |
self._t_capture = None
|
466 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
467 |
self.channel_detection = 2
|
|
|
|
|
468 |
|
469 |
self.detector_cls = detector_cls
|
470 |
self.stimulator_cls = stimulator_cls
|
@@ -552,6 +606,21 @@ class Capture:
|
|
552 |
style={'description_width': 'initial'}
|
553 |
)
|
554 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
555 |
self.b_accordion_channels = widgets.Accordion(
|
556 |
children=[
|
557 |
widgets.GridBox([
|
@@ -796,6 +865,8 @@ class Capture:
|
|
796 |
self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
|
797 |
self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
|
798 |
self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
|
|
|
|
|
799 |
self.b_power_line.observe(self.on_b_power_line, 'value')
|
800 |
self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
|
801 |
self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
|
@@ -823,6 +894,7 @@ class Capture:
|
|
823 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
824 |
widgets.HBox([self.b_threshold, self.b_test_stimulus]),
|
825 |
self.b_volume,
|
|
|
826 |
self.b_accordion_filter,
|
827 |
self.b_capture,
|
828 |
self.b_pause]))
|
@@ -846,6 +918,8 @@ class Capture:
|
|
846 |
self.b_radio_ch8.disabled = False
|
847 |
self.b_power_line.disabled = False
|
848 |
self.b_channel_detect.disabled = False
|
|
|
|
|
849 |
self.b_polyak_mean.disabled = False
|
850 |
self.b_polyak_std.disabled = False
|
851 |
self.b_epsilon.disabled = False
|
@@ -880,6 +954,8 @@ class Capture:
|
|
880 |
self.b_radio_ch7.disabled = True
|
881 |
self.b_radio_ch8.disabled = True
|
882 |
self.b_channel_detect.disabled = True
|
|
|
|
|
883 |
self.b_power_line.disabled = True
|
884 |
self.b_polyak_mean.disabled = True
|
885 |
self.b_polyak_std.disabled = True
|
@@ -916,7 +992,17 @@ class Capture:
|
|
916 |
|
917 |
def on_b_channel_detect(self, value):
|
918 |
self.channel_detection = value['new']
|
919 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
920 |
def on_b_capture(self, value):
|
921 |
val = value['new']
|
922 |
if val == 'Start':
|
@@ -1208,6 +1294,13 @@ class Capture:
|
|
1208 |
|
1209 |
buffer = []
|
1210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1211 |
while True:
|
1212 |
with self._lock_msg_out:
|
1213 |
if self._msg_out is not None:
|
@@ -1238,12 +1331,15 @@ class Capture:
|
|
1238 |
if lsl:
|
1239 |
lsl_outlet_raw.push_sample(point)
|
1240 |
lsl_outlet.push_sample(filtered_point[-1])
|
|
|
|
|
|
|
1241 |
|
1242 |
with self._pause_detect_lock:
|
1243 |
pause = self._pause_detect
|
1244 |
if detector is not None and not pause:
|
1245 |
detection_signal = detector.detect(filtered_point)
|
1246 |
-
if stimulator is not None:
|
1247 |
stimulator.stimulate(detection_signal)
|
1248 |
with self._test_stimulus_lock:
|
1249 |
test_stimulus = self._test_stimulus
|
|
|
429 |
self.volume = volume
|
430 |
|
431 |
|
432 |
+
class UpStateDelayer:
|
433 |
+
def __init__(self, sample_freq, spindle_freq, peak):
|
434 |
+
'''
|
435 |
+
args:
|
436 |
+
buffer_size: int -> Size of desired buffer in length
|
437 |
+
sample_freq: int -> Sampling frequency of signal in Hz
|
438 |
+
'''
|
439 |
+
# Get number of timesteps for a whole spindle
|
440 |
+
self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
|
441 |
+
self.sample_freq = sample_freq
|
442 |
+
self.buffer_size = 1.5 * self.spindle_timesteps
|
443 |
+
self.peak = peak
|
444 |
+
self.buffer = []
|
445 |
+
|
446 |
+
def add_point(self, point):
|
447 |
+
'''
|
448 |
+
Adds a point to the buffer to be able to keep track of peaks
|
449 |
+
'''
|
450 |
+
self.buffer.append(point)
|
451 |
+
if len(self.buffer) > self.buffer_size:
|
452 |
+
self.buffer.pop(0)
|
453 |
+
|
454 |
+
def stimulate(self):
|
455 |
+
# Calculate how far away is last peak
|
456 |
+
last_peak = -1
|
457 |
+
count = 0
|
458 |
+
for idx, point in reversed(list(enumerate(self.buffer))):
|
459 |
+
if self.peak:
|
460 |
+
try:
|
461 |
+
sup = point >= self.buffer[idx+1]
|
462 |
+
except IndexError:
|
463 |
+
sup = False
|
464 |
+
try:
|
465 |
+
inf = point >= self.buffer[idx-1]
|
466 |
+
except IndexError:
|
467 |
+
inf = False
|
468 |
+
else:
|
469 |
+
try:
|
470 |
+
sup = point <= self.buffer[idx+1]
|
471 |
+
except IndexError:
|
472 |
+
sup = False
|
473 |
+
try:
|
474 |
+
inf = point <= self.buffer[idx-1]
|
475 |
+
except IndexError:
|
476 |
+
inf = False
|
477 |
+
if sup and inf:
|
478 |
+
last_peak = count
|
479 |
+
return self.spindle_timesteps - last_peak
|
480 |
+
count += 1
|
481 |
+
return -1
|
482 |
+
|
483 |
+
|
484 |
class Capture:
|
485 |
def __init__(self, detector_cls=None, stimulator_cls=None):
|
486 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
|
|
517 |
self._t_capture = None
|
518 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
519 |
self.channel_detection = 2
|
520 |
+
self.spindle_detection_mode = 'Fast'
|
521 |
+
self.spindle_freq = 10
|
522 |
|
523 |
self.detector_cls = detector_cls
|
524 |
self.stimulator_cls = stimulator_cls
|
|
|
606 |
style={'description_width': 'initial'}
|
607 |
)
|
608 |
|
609 |
+
self.b_spindle_mode = widgets.Dropdown(
|
610 |
+
options=['Fast', 'Peak', 'Through'],
|
611 |
+
value='Fast',
|
612 |
+
description='Spindle Stimulation Mode',
|
613 |
+
disabled=False,
|
614 |
+
style={'description_width': 'initial'}
|
615 |
+
)
|
616 |
+
|
617 |
+
self.b_spindle_freq = widgets.IntText(
|
618 |
+
value=self.spindle_freq,
|
619 |
+
description='Spindle Freq (Hz):',
|
620 |
+
disabled=False,
|
621 |
+
style={'description_width': 'initial'}
|
622 |
+
)
|
623 |
+
|
624 |
self.b_accordion_channels = widgets.Accordion(
|
625 |
children=[
|
626 |
widgets.GridBox([
|
|
|
865 |
self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
|
866 |
self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
|
867 |
self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
|
868 |
+
self.b_spindle_mode.observe(self.on_b_spindle_mode, 'value')
|
869 |
+
self.b_spindle_freq.observe(self.on_b_spindle_freq, 'value')
|
870 |
self.b_power_line.observe(self.on_b_power_line, 'value')
|
871 |
self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
|
872 |
self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
|
|
|
894 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
895 |
widgets.HBox([self.b_threshold, self.b_test_stimulus]),
|
896 |
self.b_volume,
|
897 |
+
widgets.HBox([self.b_spindle_mode, self.b_spindle_freq]),
|
898 |
self.b_accordion_filter,
|
899 |
self.b_capture,
|
900 |
self.b_pause]))
|
|
|
918 |
self.b_radio_ch8.disabled = False
|
919 |
self.b_power_line.disabled = False
|
920 |
self.b_channel_detect.disabled = False
|
921 |
+
self.b_spindle_freq.disabled = False
|
922 |
+
self.b_spindle_mode.disabled = False
|
923 |
self.b_polyak_mean.disabled = False
|
924 |
self.b_polyak_std.disabled = False
|
925 |
self.b_epsilon.disabled = False
|
|
|
954 |
self.b_radio_ch7.disabled = True
|
955 |
self.b_radio_ch8.disabled = True
|
956 |
self.b_channel_detect.disabled = True
|
957 |
+
self.b_spindle_freq.disabled = True
|
958 |
+
self.b_spindle_mode.disabled = True
|
959 |
self.b_power_line.disabled = True
|
960 |
self.b_polyak_mean.disabled = True
|
961 |
self.b_polyak_std.disabled = True
|
|
|
992 |
|
993 |
def on_b_channel_detect(self, value):
|
994 |
self.channel_detection = value['new']
|
995 |
+
|
996 |
+
def on_b_spindle_freq(self, value):
|
997 |
+
val = value['new']
|
998 |
+
if val > 0:
|
999 |
+
self.spindle_freq = val
|
1000 |
+
else:
|
1001 |
+
self.b_spindle_freq.value = self.spindle_freq
|
1002 |
+
|
1003 |
+
def on_b_spindle_mode(self, value):
|
1004 |
+
self.spindle_detection_mode = value['new']
|
1005 |
+
|
1006 |
def on_b_capture(self, value):
|
1007 |
val = value['new']
|
1008 |
if val == 'Start':
|
|
|
1294 |
|
1295 |
buffer = []
|
1296 |
|
1297 |
+
if not self.spindle_detection_mode == 'Fast':
|
1298 |
+
print('here')
|
1299 |
+
stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_freq, self.spindle_detection_mode == 'Peak')
|
1300 |
+
stimulator.add_delayer(stimulation_delayer)
|
1301 |
+
else:
|
1302 |
+
stimulation_delayer = None
|
1303 |
+
|
1304 |
while True:
|
1305 |
with self._lock_msg_out:
|
1306 |
if self._msg_out is not None:
|
|
|
1331 |
if lsl:
|
1332 |
lsl_outlet_raw.push_sample(point)
|
1333 |
lsl_outlet.push_sample(filtered_point[-1])
|
1334 |
+
|
1335 |
+
if stimulation_delayer is not None:
|
1336 |
+
stimulation_delayer.add_point(point[channel-1])
|
1337 |
|
1338 |
with self._pause_detect_lock:
|
1339 |
pause = self._pause_detect
|
1340 |
if detector is not None and not pause:
|
1341 |
detection_signal = detector.detect(filtered_point)
|
1342 |
+
if stimulator is not None:
|
1343 |
stimulator.stimulate(detection_signal)
|
1344 |
with self._test_stimulus_lock:
|
1345 |
test_stimulus = self._test_stimulus
|
portiloop/stimulation.py
CHANGED
@@ -37,6 +37,8 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
37 |
self._thread = None
|
38 |
self._lock = Lock()
|
39 |
self.last_detected_ts = time.time()
|
|
|
|
|
40 |
self.wait_t = 0.4 # 400 ms
|
41 |
|
42 |
lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
|
@@ -45,6 +47,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
45 |
channel_format='string',
|
46 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
47 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
|
|
48 |
|
49 |
# Initialize Alsa stuff
|
50 |
# Open WAV file and set PCM device
|
@@ -86,11 +89,32 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
86 |
|
87 |
def stimulate(self, detection_signal):
|
88 |
for sig in detection_signal:
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
ts = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
if ts - self.last_detected_ts > self.wait_t:
|
92 |
with self._lock:
|
93 |
-
if self._thread is None:
|
94 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
95 |
self._thread.start()
|
96 |
self.last_detected_ts = ts
|
@@ -106,3 +130,6 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
106 |
if self._thread is None:
|
107 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
108 |
self._thread.start()
|
|
|
|
|
|
|
|
37 |
self._thread = None
|
38 |
self._lock = Lock()
|
39 |
self.last_detected_ts = time.time()
|
40 |
+
self.wait_counter = 0
|
41 |
+
self.delayed = False
|
42 |
self.wait_t = 0.4 # 400 ms
|
43 |
|
44 |
lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
|
|
|
47 |
channel_format='string',
|
48 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
49 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
50 |
+
self.delayer = None
|
51 |
|
52 |
# Initialize Alsa stuff
|
53 |
# Open WAV file and set PCM device
|
|
|
89 |
|
90 |
def stimulate(self, detection_signal):
|
91 |
for sig in detection_signal:
|
92 |
+
# We are waiting for a delayed stimulation
|
93 |
+
if self.delayed:
|
94 |
+
if self.wait_counter >= self.wait_time:
|
95 |
+
with self._lock:
|
96 |
+
if self._thread is None:
|
97 |
+
self._thread = Thread(target=self._t_sound, daemon=True)
|
98 |
+
self._thread.start()
|
99 |
+
self.delayed = False
|
100 |
+
else:
|
101 |
+
self.wait_counter += 1
|
102 |
+
# We detect a stimulation
|
103 |
+
elif sig:
|
104 |
+
# Record time of stimulation
|
105 |
ts = time.time()
|
106 |
+
|
107 |
+
# Prompt delayer to try and get a stimulation
|
108 |
+
if self.delayer is not None:
|
109 |
+
self.wait_time = self.delayer.stimulate()
|
110 |
+
self.delayed = True
|
111 |
+
self.wait_counter = 0
|
112 |
+
continue
|
113 |
+
|
114 |
+
# Stimulate if allowed
|
115 |
if ts - self.last_detected_ts > self.wait_t:
|
116 |
with self._lock:
|
117 |
+
if self._thread is None:
|
118 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
119 |
self._thread.start()
|
120 |
self.last_detected_ts = ts
|
|
|
130 |
if self._thread is None:
|
131 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
132 |
self._thread.start()
|
133 |
+
|
134 |
+
def add_delayer(self, delayer):
|
135 |
+
self.delayer = delayer
|