Yann Bouteiller commited on
Commit
6da664d
·
1 Parent(s): 0671861

Added spindle stimulation mode

Browse files
Files changed (2) hide show
  1. portiloop/capture.py +98 -2
  2. portiloop/stimulation.py +29 -2
portiloop/capture.py CHANGED
@@ -429,6 +429,58 @@ class DummyAlsaMixer:
429
  self.volume = volume
430
 
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  class Capture:
433
  def __init__(self, detector_cls=None, stimulator_cls=None):
434
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
@@ -465,6 +517,8 @@ class Capture:
465
  self._t_capture = None
466
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
467
  self.channel_detection = 2
 
 
468
 
469
  self.detector_cls = detector_cls
470
  self.stimulator_cls = stimulator_cls
@@ -552,6 +606,21 @@ class Capture:
552
  style={'description_width': 'initial'}
553
  )
554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  self.b_accordion_channels = widgets.Accordion(
556
  children=[
557
  widgets.GridBox([
@@ -796,6 +865,8 @@ class Capture:
796
  self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
797
  self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
798
  self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
 
 
799
  self.b_power_line.observe(self.on_b_power_line, 'value')
800
  self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
801
  self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
@@ -823,6 +894,7 @@ class Capture:
823
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
824
  widgets.HBox([self.b_threshold, self.b_test_stimulus]),
825
  self.b_volume,
 
826
  self.b_accordion_filter,
827
  self.b_capture,
828
  self.b_pause]))
@@ -846,6 +918,8 @@ class Capture:
846
  self.b_radio_ch8.disabled = False
847
  self.b_power_line.disabled = False
848
  self.b_channel_detect.disabled = False
 
 
849
  self.b_polyak_mean.disabled = False
850
  self.b_polyak_std.disabled = False
851
  self.b_epsilon.disabled = False
@@ -880,6 +954,8 @@ class Capture:
880
  self.b_radio_ch7.disabled = True
881
  self.b_radio_ch8.disabled = True
882
  self.b_channel_detect.disabled = True
 
 
883
  self.b_power_line.disabled = True
884
  self.b_polyak_mean.disabled = True
885
  self.b_polyak_std.disabled = True
@@ -916,7 +992,17 @@ class Capture:
916
 
917
  def on_b_channel_detect(self, value):
918
  self.channel_detection = value['new']
919
-
 
 
 
 
 
 
 
 
 
 
920
  def on_b_capture(self, value):
921
  val = value['new']
922
  if val == 'Start':
@@ -1208,6 +1294,13 @@ class Capture:
1208
 
1209
  buffer = []
1210
 
 
 
 
 
 
 
 
1211
  while True:
1212
  with self._lock_msg_out:
1213
  if self._msg_out is not None:
@@ -1238,12 +1331,15 @@ class Capture:
1238
  if lsl:
1239
  lsl_outlet_raw.push_sample(point)
1240
  lsl_outlet.push_sample(filtered_point[-1])
 
 
 
1241
 
1242
  with self._pause_detect_lock:
1243
  pause = self._pause_detect
1244
  if detector is not None and not pause:
1245
  detection_signal = detector.detect(filtered_point)
1246
- if stimulator is not None:
1247
  stimulator.stimulate(detection_signal)
1248
  with self._test_stimulus_lock:
1249
  test_stimulus = self._test_stimulus
 
429
  self.volume = volume
430
 
431
 
432
+ class UpStateDelayer:
433
+ def __init__(self, sample_freq, spindle_freq, peak):
434
+ '''
435
+ args:
436
+ buffer_size: int -> Size of desired buffer in length
437
+ sample_freq: int -> Sampling frequency of signal in Hz
438
+ '''
439
+ # Get number of timesteps for a whole spindle
440
+ self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
441
+ self.sample_freq = sample_freq
442
+ self.buffer_size = 1.5 * self.spindle_timesteps
443
+ self.peak = peak
444
+ self.buffer = []
445
+
446
+ def add_point(self, point):
447
+ '''
448
+ Adds a point to the buffer to be able to keep track of peaks
449
+ '''
450
+ self.buffer.append(point)
451
+ if len(self.buffer) > self.buffer_size:
452
+ self.buffer.pop(0)
453
+
454
+ def stimulate(self):
455
+ # Calculate how far away is last peak
456
+ last_peak = -1
457
+ count = 0
458
+ for idx, point in reversed(list(enumerate(self.buffer))):
459
+ if self.peak:
460
+ try:
461
+ sup = point >= self.buffer[idx+1]
462
+ except IndexError:
463
+ sup = False
464
+ try:
465
+ inf = point >= self.buffer[idx-1]
466
+ except IndexError:
467
+ inf = False
468
+ else:
469
+ try:
470
+ sup = point <= self.buffer[idx+1]
471
+ except IndexError:
472
+ sup = False
473
+ try:
474
+ inf = point <= self.buffer[idx-1]
475
+ except IndexError:
476
+ inf = False
477
+ if sup and inf:
478
+ last_peak = count
479
+ return self.spindle_timesteps - last_peak
480
+ count += 1
481
+ return -1
482
+
483
+
484
  class Capture:
485
  def __init__(self, detector_cls=None, stimulator_cls=None):
486
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
 
517
  self._t_capture = None
518
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
519
  self.channel_detection = 2
520
+ self.spindle_detection_mode = 'Fast'
521
+ self.spindle_freq = 10
522
 
523
  self.detector_cls = detector_cls
524
  self.stimulator_cls = stimulator_cls
 
606
  style={'description_width': 'initial'}
607
  )
608
 
609
+ self.b_spindle_mode = widgets.Dropdown(
610
+ options=['Fast', 'Peak', 'Through'],
611
+ value='Fast',
612
+ description='Spindle Stimulation Mode',
613
+ disabled=False,
614
+ style={'description_width': 'initial'}
615
+ )
616
+
617
+ self.b_spindle_freq = widgets.IntText(
618
+ value=self.spindle_freq,
619
+ description='Spindle Freq (Hz):',
620
+ disabled=False,
621
+ style={'description_width': 'initial'}
622
+ )
623
+
624
  self.b_accordion_channels = widgets.Accordion(
625
  children=[
626
  widgets.GridBox([
 
865
  self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
866
  self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
867
  self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
868
+ self.b_spindle_mode.observe(self.on_b_spindle_mode, 'value')
869
+ self.b_spindle_freq.observe(self.on_b_spindle_freq, 'value')
870
  self.b_power_line.observe(self.on_b_power_line, 'value')
871
  self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
872
  self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
 
894
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
895
  widgets.HBox([self.b_threshold, self.b_test_stimulus]),
896
  self.b_volume,
897
+ widgets.HBox([self.b_spindle_mode, self.b_spindle_freq]),
898
  self.b_accordion_filter,
899
  self.b_capture,
900
  self.b_pause]))
 
918
  self.b_radio_ch8.disabled = False
919
  self.b_power_line.disabled = False
920
  self.b_channel_detect.disabled = False
921
+ self.b_spindle_freq.disabled = False
922
+ self.b_spindle_mode.disabled = False
923
  self.b_polyak_mean.disabled = False
924
  self.b_polyak_std.disabled = False
925
  self.b_epsilon.disabled = False
 
954
  self.b_radio_ch7.disabled = True
955
  self.b_radio_ch8.disabled = True
956
  self.b_channel_detect.disabled = True
957
+ self.b_spindle_freq.disabled = True
958
+ self.b_spindle_mode.disabled = True
959
  self.b_power_line.disabled = True
960
  self.b_polyak_mean.disabled = True
961
  self.b_polyak_std.disabled = True
 
992
 
993
  def on_b_channel_detect(self, value):
994
  self.channel_detection = value['new']
995
+
996
+ def on_b_spindle_freq(self, value):
997
+ val = value['new']
998
+ if val > 0:
999
+ self.spindle_freq = val
1000
+ else:
1001
+ self.b_spindle_freq.value = self.spindle_freq
1002
+
1003
+ def on_b_spindle_mode(self, value):
1004
+ self.spindle_detection_mode = value['new']
1005
+
1006
  def on_b_capture(self, value):
1007
  val = value['new']
1008
  if val == 'Start':
 
1294
 
1295
  buffer = []
1296
 
1297
+ if not self.spindle_detection_mode == 'Fast':
1298
+ print('here')
1299
+ stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_freq, self.spindle_detection_mode == 'Peak')
1300
+ stimulator.add_delayer(stimulation_delayer)
1301
+ else:
1302
+ stimulation_delayer = None
1303
+
1304
  while True:
1305
  with self._lock_msg_out:
1306
  if self._msg_out is not None:
 
1331
  if lsl:
1332
  lsl_outlet_raw.push_sample(point)
1333
  lsl_outlet.push_sample(filtered_point[-1])
1334
+
1335
+ if stimulation_delayer is not None:
1336
+ stimulation_delayer.add_point(point[channel-1])
1337
 
1338
  with self._pause_detect_lock:
1339
  pause = self._pause_detect
1340
  if detector is not None and not pause:
1341
  detection_signal = detector.detect(filtered_point)
1342
+ if stimulator is not None:
1343
  stimulator.stimulate(detection_signal)
1344
  with self._test_stimulus_lock:
1345
  test_stimulus = self._test_stimulus
portiloop/stimulation.py CHANGED
@@ -37,6 +37,8 @@ class SleepSpindleRealTimeStimulator(Stimulator):
37
  self._thread = None
38
  self._lock = Lock()
39
  self.last_detected_ts = time.time()
 
 
40
  self.wait_t = 0.4 # 400 ms
41
 
42
  lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
@@ -45,6 +47,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
45
  channel_format='string',
46
  source_id='portiloop1') # TODO: replace this by unique device identifier
47
  self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
 
48
 
49
  # Initialize Alsa stuff
50
  # Open WAV file and set PCM device
@@ -86,11 +89,32 @@ class SleepSpindleRealTimeStimulator(Stimulator):
86
 
87
  def stimulate(self, detection_signal):
88
  for sig in detection_signal:
89
- if sig:
 
 
 
 
 
 
 
 
 
 
 
 
90
  ts = time.time()
 
 
 
 
 
 
 
 
 
91
  if ts - self.last_detected_ts > self.wait_t:
92
  with self._lock:
93
- if self._thread is None:
94
  self._thread = Thread(target=self._t_sound, daemon=True)
95
  self._thread.start()
96
  self.last_detected_ts = ts
@@ -106,3 +130,6 @@ class SleepSpindleRealTimeStimulator(Stimulator):
106
  if self._thread is None:
107
  self._thread = Thread(target=self._t_sound, daemon=True)
108
  self._thread.start()
 
 
 
 
37
  self._thread = None
38
  self._lock = Lock()
39
  self.last_detected_ts = time.time()
40
+ self.wait_counter = 0
41
+ self.delayed = False
42
  self.wait_t = 0.4 # 400 ms
43
 
44
  lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
 
47
  channel_format='string',
48
  source_id='portiloop1') # TODO: replace this by unique device identifier
49
  self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
50
+ self.delayer = None
51
 
52
  # Initialize Alsa stuff
53
  # Open WAV file and set PCM device
 
89
 
90
  def stimulate(self, detection_signal):
91
  for sig in detection_signal:
92
+ # We are waiting for a delayed stimulation
93
+ if self.delayed:
94
+ if self.wait_counter >= self.wait_time:
95
+ with self._lock:
96
+ if self._thread is None:
97
+ self._thread = Thread(target=self._t_sound, daemon=True)
98
+ self._thread.start()
99
+ self.delayed = False
100
+ else:
101
+ self.wait_counter += 1
102
+ # We detect a stimulation
103
+ elif sig:
104
+ # Record time of stimulation
105
  ts = time.time()
106
+
107
+ # Prompt delayer to try and get a stimulation
108
+ if self.delayer is not None:
109
+ self.wait_time = self.delayer.stimulate()
110
+ self.delayed = True
111
+ self.wait_counter = 0
112
+ continue
113
+
114
+ # Stimulate if allowed
115
  if ts - self.last_detected_ts > self.wait_t:
116
  with self._lock:
117
+ if self._thread is None:
118
  self._thread = Thread(target=self._t_sound, daemon=True)
119
  self._thread.start()
120
  self.last_detected_ts = ts
 
130
  if self._thread is None:
131
  self._thread = Thread(target=self._t_sound, daemon=True)
132
  self._thread.start()
133
+
134
+ def add_delayer(self, delayer):
135
+ self.delayer = delayer