ybouteiller commited on
Commit
45a88e4
·
1 Parent(s): ba200d7

full pipeline

Browse files
portiloop/capture.py CHANGED
@@ -12,6 +12,7 @@ import multiprocessing as mp
12
  import warnings
13
  import shutil
14
  from threading import Thread, Lock
 
15
 
16
  import matplotlib.pyplot as plt
17
  from EDFlib.edfwriter import EDFwriter
@@ -198,7 +199,7 @@ class FilterPipeline:
198
  sampling_rate,
199
  power_line_fq=60,
200
  use_custom_fir=False,
201
- custom_fir_order=10,
202
  custom_fir_cutoff=30,
203
  alpha_avg=0.1,
204
  alpha_std=0.001,
@@ -411,7 +412,7 @@ class Capture:
411
  self.polyak_std = 0.001
412
  self.epsilon = 0.000001
413
  self.custom_fir = False
414
- self.custom_fir_order = 10
415
  self.custom_fir_cutoff = 30
416
  self.filter = True
417
  self.record = False
@@ -436,6 +437,19 @@ class Capture:
436
  self.detector_cls = detector_cls
437
  self.stimulator_cls = stimulator_cls
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  # widgets ===============================
440
 
441
  # CHANNELS ------------------------------
@@ -657,6 +671,22 @@ class Capture:
657
  indent=False
658
  )
659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  # CALLBACKS ----------------------
661
 
662
  self.b_capture.observe(self.on_b_capture, 'value')
@@ -684,6 +714,8 @@ class Capture:
684
  self.b_polyak_mean.observe(self.on_b_polyak_mean, 'value')
685
  self.b_polyak_std.observe(self.on_b_polyak_std, 'value')
686
  self.b_epsilon.observe(self.on_b_epsilon, 'value')
 
 
687
 
688
  self.display_buttons()
689
 
@@ -698,7 +730,8 @@ class Capture:
698
  self.b_power_line,
699
  self.b_clock,
700
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
701
- self.b_threshold,
 
702
  self.b_accordion_filter,
703
  self.b_capture]))
704
 
@@ -727,6 +760,7 @@ class Capture:
727
  self.b_custom_fir_cutoff.disabled = not self.custom_fir
728
  self.b_stimulate.disabled = not self.detect
729
  self.b_threshold.disabled = not self.detect
 
730
 
731
  def disable_buttons(self):
732
  self.b_frequency.disabled = True
@@ -754,6 +788,7 @@ class Capture:
754
  self.b_custom_fir_order.disabled = True
755
  self.b_custom_fir_cutoff.disabled = True
756
  self.b_threshold.disabled = True
 
757
 
758
  def on_b_radio_ch2(self, value):
759
  self.channel_states[1] = value['new']
@@ -789,6 +824,7 @@ class Capture:
789
  return
790
  detector_cls = self.detector_cls if self.detect else None
791
  stimulator_cls = self.stimulator_cls if self.stimulate else None
 
792
  self._t_capture = Thread(target=self.start_capture,
793
  args=(self.filter,
794
  detector_cls,
@@ -918,6 +954,16 @@ class Capture:
918
  def on_b_display(self, value):
919
  val = value['new']
920
  self.display = val
 
 
 
 
 
 
 
 
 
 
921
 
922
  def open_recording_file(self):
923
  nb_signals = self.nb_signals
@@ -1011,8 +1057,9 @@ class Capture:
1011
  lsl_info = StreamInfo(name='Portiloop',
1012
  type='EEG',
1013
  channel_count=8,
 
1014
  channel_format='float32',
1015
- source_id='') # TODO: replace this by unique device identifier
1016
  lsl_outlet = StreamOutlet(lsl_info)
1017
 
1018
  buffer = []
@@ -1046,9 +1093,13 @@ class Capture:
1046
 
1047
  if detector is not None:
1048
  detection_signal = detector.detect(filtered_point)
1049
-
1050
  if stimulator is not None:
1051
  stimulator.stimulate(detection_signal)
 
 
 
 
 
1052
 
1053
  if lsl:
1054
  lsl_outlet.push_sample(filtered_point[-1])
 
12
  import warnings
13
  import shutil
14
  from threading import Thread, Lock
15
+ import alsaaudio
16
 
17
  import matplotlib.pyplot as plt
18
  from EDFlib.edfwriter import EDFwriter
 
199
  sampling_rate,
200
  power_line_fq=60,
201
  use_custom_fir=False,
202
+ custom_fir_order=20,
203
  custom_fir_cutoff=30,
204
  alpha_avg=0.1,
205
  alpha_std=0.001,
 
412
  self.polyak_std = 0.001
413
  self.epsilon = 0.000001
414
  self.custom_fir = False
415
+ self.custom_fir_order = 20
416
  self.custom_fir_cutoff = 30
417
  self.filter = True
418
  self.record = False
 
437
  self.detector_cls = detector_cls
438
  self.stimulator_cls = stimulator_cls
439
 
440
+ self._test_stimulus_lock = Lock()
441
+ self._test_stimulus = False
442
+
443
+ mixers = alsaaudio.mixers()
444
+ if 'PCM' in mixers:
445
+ self.mixer = alsaaudio.Mixer(control='PCM')
446
+ else:
447
+ assert len(mixers) > 0, 'No ALSA mixer found'
448
+ warnings.warn(f"Could not find mixer PCM, using {mixers[0]} instead.")
449
+ self.mixer = alsaaudio.Mixer(control=mixers[0])
450
+ self.volume = self.mixer.getvolume()[0] # we will set the same volume on all channels
451
+
452
+
453
  # widgets ===============================
454
 
455
  # CHANNELS ------------------------------
 
671
  indent=False
672
  )
673
 
674
+ self.b_volume = widgets.IntSlider(
675
+ value=self.volume,
676
+ min=0,
677
+ max=100,
678
+ step=1,
679
+ description="Volume",
680
+ disabled=False
681
+ )
682
+
683
+ self.b_test_stimulus = widgets.Button(
684
+ description='Test stimulus',
685
+ disabled=True,
686
+ button_style='', # 'success', 'info', 'warning', 'danger' or ''
687
+ tooltip='Send a test stimulus'
688
+ )
689
+
690
  # CALLBACKS ----------------------
691
 
692
  self.b_capture.observe(self.on_b_capture, 'value')
 
714
  self.b_polyak_mean.observe(self.on_b_polyak_mean, 'value')
715
  self.b_polyak_std.observe(self.on_b_polyak_std, 'value')
716
  self.b_epsilon.observe(self.on_b_epsilon, 'value')
717
+ self.b_volume.observe(self.on_b_volume, 'value')
718
+ self.b_test_stimulus.on_click(self.on_b_test_stimulus)
719
 
720
  self.display_buttons()
721
 
 
730
  self.b_power_line,
731
  self.b_clock,
732
  widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
733
+ widgets.HBox([self.b_threshold, self.b_test_stimulus]),
734
+ self.b_volume,
735
  self.b_accordion_filter,
736
  self.b_capture]))
737
 
 
760
  self.b_custom_fir_cutoff.disabled = not self.custom_fir
761
  self.b_stimulate.disabled = not self.detect
762
  self.b_threshold.disabled = not self.detect
763
+ self.b_test_stimulus.disabled = True # only enabled when running
764
 
765
  def disable_buttons(self):
766
  self.b_frequency.disabled = True
 
788
  self.b_custom_fir_order.disabled = True
789
  self.b_custom_fir_cutoff.disabled = True
790
  self.b_threshold.disabled = True
791
+ self.b_test_stimulus.disabled = not self.stimulate # only enabled when running
792
 
793
  def on_b_radio_ch2(self, value):
794
  self.channel_states[1] = value['new']
 
824
  return
825
  detector_cls = self.detector_cls if self.detect else None
826
  stimulator_cls = self.stimulator_cls if self.stimulate else None
827
+
828
  self._t_capture = Thread(target=self.start_capture,
829
  args=(self.filter,
830
  detector_cls,
 
954
  def on_b_display(self, value):
955
  val = value['new']
956
  self.display = val
957
+
958
+ def on_b_volume(self, value):
959
+ val = value['new']
960
+ if val >= 0 and val <= 100:
961
+ self.volume = val
962
+ self.mixer.setvolume(self.volume)
963
+
964
+ def on_b_test_stimulus(self, b):
965
+ with self._test_stimulus_lock:
966
+ self._test_stimulus = True
967
 
968
  def open_recording_file(self):
969
  nb_signals = self.nb_signals
 
1057
  lsl_info = StreamInfo(name='Portiloop',
1058
  type='EEG',
1059
  channel_count=8,
1060
+ nominal_srate=self.frequency,
1061
  channel_format='float32',
1062
+ source_id='portiloop1') # TODO: replace this by unique device identifier
1063
  lsl_outlet = StreamOutlet(lsl_info)
1064
 
1065
  buffer = []
 
1093
 
1094
  if detector is not None:
1095
  detection_signal = detector.detect(filtered_point)
 
1096
  if stimulator is not None:
1097
  stimulator.stimulate(detection_signal)
1098
+ with self._test_stimulus_lock:
1099
+ test_stimulus = self._test_stimulus
1100
+ self._test_stimulus = False
1101
+ if test_stimulus:
1102
+ stimulator.test_stimulus()
1103
 
1104
  if lsl:
1105
  lsl_outlet.push_sample(filtered_point[-1])
portiloop/detection.py CHANGED
@@ -82,51 +82,68 @@ class SleepSpindleRealTimeDetector(Detector):
82
  return res
83
 
84
  def add_datapoint(self, input_float):
 
 
 
85
  input_float = input_float[self.channel - 1]
86
  result = None
 
87
  self.buffer.append(input_float)
88
  if len(self.buffer) > self.window_size:
 
89
  self.buffer = self.buffer[1:]
90
  self.current_stride_counter += 1
91
  if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
92
- result = self.call_model(self.interpreter_counter, self.buffer)
 
93
  self.interpreter_counter += 1
94
  self.interpreter_counter %= self.num_models_parallel
95
  self.current_stride_counter = 0
96
  return result
97
-
98
- def call_model(self, idx, input_float=None):
99
- if input_float is None:
100
- # For debugging purposes
101
- input_shape = self.input_details[0]['shape']
102
- input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
103
- else:
104
- # Convert float input to Int
105
- input_scale, input_zero_point = self.input_details[0]["quantization"]
106
- input = np.asarray(input_float) / input_scale + input_zero_point
107
- input = input.astype(self.input_details[0]["dtype"])
108
- input = input.reshape((1, 1, -1))
109
-
110
- # TODO: Milo please implement this:
111
- # self.interpreters[idx].set_tensor(self.input_details[0]['index'], (self.h[idx], input))
112
 
113
- # if self.verbose:
114
- # start_time = time.time()
 
115
 
116
- # self.interpreters[idx].invoke()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- # if self.verbose:
119
- # end_time = time.time()
120
- # output, self.h[idx] = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
121
- # output_scale, output_zero_point = self.input_details[0]["quantization"]
122
- # output = float(output - output_zero_point) * output_scale
123
 
124
- # TODO: remove this line:
125
- output = np.random.uniform() # FIXME: remove
 
 
 
 
 
126
 
 
 
 
127
  if self.verbose:
128
- print(f"Computed output {output} in {end_time - start_time} seconds")
129
 
130
- return output
 
 
131
 
132
 
 
82
  return res
83
 
84
  def add_datapoint(self, input_float):
85
+ '''
86
+ Add one datapoint to the buffer
87
+ '''
88
  input_float = input_float[self.channel - 1]
89
  result = None
90
+ # Add to current buffer
91
  self.buffer.append(input_float)
92
  if len(self.buffer) > self.window_size:
93
+ # Remove the end of the buffer
94
  self.buffer = self.buffer[1:]
95
  self.current_stride_counter += 1
96
  if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
97
+ # If we have reached the next window size, we send the current buffer to the inference function and update the hidden state
98
+ result, self.h[self.interpreter_counter] = self.forward_tflite(self.interpreter_counter, self.buffer, self.h[self.interpreter_counter])
99
  self.interpreter_counter += 1
100
  self.interpreter_counter %= self.num_models_parallel
101
  self.current_stride_counter = 0
102
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ def forward_tflite(self, idx, input_x, input_h):
105
+ input_details = self.interpreters[idx].get_input_details()
106
+ output_details = self.interpreters[idx].get_output_details()
107
 
108
+ # convert input to int
109
+ input_scale, input_zero_point = input_details[1]["quantization"]
110
+ input_x = np.asarray(input_x) / input_scale + input_zero_point
111
+ input_data_x = input_x.astype(input_details[1]["dtype"])
112
+ input_data_x = np.expand_dims(input_data_x, (0, 1))
113
+
114
+ # input_scale, input_zero_point = input_details[0]["quantization"]
115
+ # input = np.asarray(input) / input_scale + input_zero_point
116
+
117
+ # Test the model on random input data.
118
+ input_shape_h = input_details[0]['shape']
119
+ input_shape_x = input_details[1]['shape']
120
+
121
+ # input_data_h = np.array(np.random.random_sample(input_shape_h), dtype=np.int8)
122
+ # input_data_x = np.array(np.random.random_sample(input_shape_x), dtype=np.int8)
123
+ self.interpreters[idx].set_tensor(input_details[0]['index'], input_h)
124
+ self.interpreters[idx].set_tensor(input_details[1]['index'], input_data_x)
125
 
126
+ if self.verbose:
127
+ start_time = time.time()
128
+
129
+ self.interpreters[idx].invoke()
 
130
 
131
+ if self.verbose:
132
+ end_time = time.time()
133
+
134
+ # The function `get_tensor()` returns a copy of the tensor data.
135
+ # Use `tensor()` in order to get a pointer to the tensor.
136
+ output_data_h = self.interpreters[idx].get_tensor(output_details[0]['index'])
137
+ output_data_y = self.interpreters[idx].get_tensor(output_details[1]['index'])
138
 
139
+ output_scale, output_zero_point = output_details[1]["quantization"]
140
+ output_data_y = float(output_data_y - output_zero_point) * output_scale
141
+
142
  if self.verbose:
143
+ print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
144
 
145
+ return output_data_y, output_data_h
146
+
147
+
148
 
149
 
portiloop/notebooks/tests.ipynb CHANGED
@@ -2,24 +2,61 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "source": [
12
  "from portiloop.capture import Capture\n",
13
  "from portiloop.detection import SleepSpindleRealTimeDetector\n",
14
  "from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
15
  "\n",
16
- "cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
 
 
 
17
  ]
18
  },
19
  {
20
  "cell_type": "code",
21
  "execution_count": null,
22
- "id": "9bd24d7a",
23
  "metadata": {},
24
  "outputs": [],
25
  "source": []
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
+ "outputs": [
11
+ {
12
+ "data": {
13
+ "application/vnd.jupyter.widget-view+json": {
14
+ "model_id": "5bd498c14c0b47ef8fc0c7b25d6197c0",
15
+ "version_major": 2,
16
+ "version_minor": 0
17
+ },
18
+ "text/plain": [
19
+ "VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
20
+ ]
21
+ },
22
+ "metadata": {},
23
+ "output_type": "display_data"
24
+ },
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "DEBUG:/home/mendel/software/portiloop-software/portiloop/sounds/stimulus.wav\n",
30
+ "PID capture: 4311\n",
31
+ "DEBUG: new config[5]:0xe1\n",
32
+ "DEBUG: new config[6]:0xe1\n",
33
+ "DEBUG: new config[7]:0xe1\n",
34
+ "DEBUG: new config[8]:0xe1\n",
35
+ "DEBUG: new config[9]:0xe1\n",
36
+ "DEBUG: new config[10]:0xe1\n",
37
+ "DEBUG: new config[11]:0xe1\n",
38
+ "DEBUG: new config[12]:0xe1\n",
39
+ "DEBUG: new config[13]:0x0\n",
40
+ "DEBUG: new config[14]:0x0\n",
41
+ "DEBUG: new config[3]:0xe8\n"
42
+ ]
43
+ }
44
+ ],
45
  "source": [
46
  "from portiloop.capture import Capture\n",
47
  "from portiloop.detection import SleepSpindleRealTimeDetector\n",
48
  "from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
49
  "\n",
50
+ "my_detector_class = SleepSpindleRealTimeDetector # you may want to implement yours\n",
51
+ "my_stimulator_class = SleepSpindleRealTimeStimulator # you may also want to implement yours\n",
52
+ "\n",
53
+ "cap = Capture(detector_cls=my_detector_class, stimulator_cls=my_stimulator_class)"
54
  ]
55
  },
56
  {
57
  "cell_type": "code",
58
  "execution_count": null,
59
+ "id": "fd7c79a7",
60
  "metadata": {},
61
  "outputs": [],
62
  "source": []
portiloop/stimulation.py CHANGED
@@ -1,8 +1,10 @@
1
  from abc import ABC, abstractmethod
2
  import time
3
- from playsound import playsound
4
  from threading import Thread, Lock
5
  from pathlib import Path
 
 
 
6
 
7
 
8
  # Abstract interface for developers:
@@ -18,6 +20,12 @@ class Stimulator(ABC):
18
  detection_signal: Object: the output of the Detector.add_datapoints method.
19
  """
20
  raise NotImplementedError
 
 
 
 
 
 
21
 
22
 
23
  # Example implementation for sleep spindles
@@ -30,6 +38,52 @@ class SleepSpindleRealTimeStimulator(Stimulator):
30
  self._lock = Lock()
31
  self.last_detected_ts = time.time()
32
  self.wait_t = 0.4 # 400 ms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def stimulate(self, detection_signal):
35
  for sig in detection_signal:
@@ -43,6 +97,13 @@ class SleepSpindleRealTimeStimulator(Stimulator):
43
  self.last_detected_ts = ts
44
 
45
  def _t_sound(self):
46
- playsound(self._sound)
 
47
  with self._lock:
48
  self._thread = None
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
  import time
 
3
  from threading import Thread, Lock
4
  from pathlib import Path
5
+ import alsaaudio
6
+ import wave
7
+ import pylsl
8
 
9
 
10
  # Abstract interface for developers:
 
20
  detection_signal: Object: the output of the Detector.add_datapoints method.
21
  """
22
  raise NotImplementedError
23
+
24
+ def test_stimulus(self):
25
+ """
26
+ Optional: this is called when the 'Test stimulus' button is pressed.
27
+ """
28
+ pass
29
 
30
 
31
  # Example implementation for sleep spindles
 
38
  self._lock = Lock()
39
  self.last_detected_ts = time.time()
40
  self.wait_t = 0.4 # 400 ms
41
+
42
+ lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
43
+ type='Markers',
44
+ channel_count=1,
45
+ channel_format='string',
46
+ source_id='portiloop1') # TODO: replace this by unique device identifier
47
+ self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
48
+
49
+ # Initialize Alsa stuff
50
+ # Open WAV file and set PCM device
51
+ with wave.open(str(self._sound), 'rb') as f:
52
+ device = 'default'
53
+
54
+ format = None
55
+
56
+ # 8bit is unsigned in wav files
57
+ if f.getsampwidth() == 1:
58
+ format = alsaaudio.PCM_FORMAT_U8
59
+ # Otherwise we assume signed data, little endian
60
+ elif f.getsampwidth() == 2:
61
+ format = alsaaudio.PCM_FORMAT_S16_LE
62
+ elif f.getsampwidth() == 3:
63
+ format = alsaaudio.PCM_FORMAT_S24_3LE
64
+ elif f.getsampwidth() == 4:
65
+ format = alsaaudio.PCM_FORMAT_S32_LE
66
+ else:
67
+ raise ValueError('Unsupported format')
68
+
69
+ self.periodsize = f.getframerate() // 8
70
+
71
+ self.pcm = alsaaudio.PCM(channels=f.getnchannels(), rate=f.getframerate(), format=format, periodsize=self.periodsize, device=device)
72
+
73
+ # Store data in list to avoid reopening the file
74
+ data = f.readframes(self.periodsize)
75
+ self.wav_list = [data]
76
+ while data:
77
+ self.wav_list.append(data)
78
+ data = f.readframes(self.periodsize)
79
+
80
+
81
+ def play_sound(self):
82
+ '''
83
+ Open the wav file and play a sound
84
+ '''
85
+ for data in self.wav_list:
86
+ self.pcm.write(data)
87
 
88
  def stimulate(self, detection_signal):
89
  for sig in detection_signal:
 
97
  self.last_detected_ts = ts
98
 
99
  def _t_sound(self):
100
+ self.lsl_outlet_markers.push_sample(['STIM'])
101
+ self.play_sound()
102
  with self._lock:
103
  self._thread = None
104
+
105
+ def test_stimulus(self):
106
+ with self._lock:
107
+ if self._thread is None:
108
+ self._thread = Thread(target=self._t_sound, daemon=True)
109
+ self._thread.start()