ybouteiller commited on
Commit
35cdf83
·
1 Parent(s): 120f728

debugged and cleaned + implement stimulation skeletton

Browse files
portiloop/capture.py CHANGED
@@ -399,7 +399,7 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
399
 
400
 
401
  class Capture:
402
- def __init__(self, quantInferenceClass):
403
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
404
  self.filename = EDF_PATH / 'recording.edf'
405
  self._p_capture = None
@@ -433,7 +433,8 @@ class Capture:
433
  self._t_capture = None
434
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
435
 
436
- self.quantInferenceClass = quantInferenceClass
 
437
 
438
  # widgets ===============================
439
 
@@ -665,6 +666,7 @@ class Capture:
665
  self.b_duration.observe(self.on_b_duration, 'value')
666
  self.b_filter.observe(self.on_b_filter, 'value')
667
  self.b_detect.observe(self.on_b_detect, 'value')
 
668
  self.b_record.observe(self.on_b_record, 'value')
669
  self.b_lsl.observe(self.on_b_lsl, 'value')
670
  self.b_display.observe(self.on_b_display, 'value')
@@ -707,7 +709,7 @@ class Capture:
707
  self.b_filter.disabled = False
708
  self.b_detect.disabled = False
709
  self.b_record.disabled = False
710
- self.b_record.lsl = False
711
  self.b_display.disabled = False
712
  self.b_clock.disabled = False
713
  self.b_radio_ch2.disabled = False
@@ -733,8 +735,9 @@ class Capture:
733
  self.b_filter.disabled = True
734
  self.b_stimulate.disabled = True
735
  self.b_filter.disabled = True
 
736
  self.b_record.disabled = True
737
- self.b_record.lsl = True
738
  self.b_display.disabled = True
739
  self.b_clock.disabled = True
740
  self.b_radio_ch2.disabled = True
@@ -784,8 +787,18 @@ class Capture:
784
  if self._t_capture is not None:
785
  warnings.warn("Capture already running, operation aborted.")
786
  return
 
 
787
  self._t_capture = Thread(target=self.start_capture,
788
- args=(self.filter, self.detect, self.quantInferenceClass, self.record, self.lsl, self.display, 500, self.python_clock))
 
 
 
 
 
 
 
 
789
  self._t_capture.start()
790
  elif val == 'Stop':
791
  with self._lock_msg_out:
@@ -944,8 +957,9 @@ class Capture:
944
 
945
  def start_capture(self,
946
  filter,
947
- detect,
948
- quantInferenceClass,
 
949
  record,
950
  lsl,
951
  viz,
@@ -971,8 +985,8 @@ class Capture:
971
  alpha_std=self.polyak_std,
972
  epsilon=self.epsilon)
973
 
974
- if detect:
975
- infer = quantInferenceClass()
976
 
977
  self._p_capture = mp.Process(target=_capture_process,
978
  args=(p_data_o,
@@ -984,7 +998,7 @@ class Capture:
984
  self.channel_states)
985
  )
986
  self._p_capture.start()
987
- # print(f"PID capture: {self._p_capture.pid}")
988
 
989
  if viz:
990
  live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
@@ -1030,14 +1044,11 @@ class Capture:
1030
 
1031
  filtered_point = n_array.tolist()
1032
 
1033
- if detect:
1034
- results = infer.add_datapoints(filtered_points)
1035
 
1036
- for r in results:
1037
- print(r >= threshold)
1038
-
1039
- if stimulate and True:
1040
- print('stimulation')
1041
 
1042
  if lsl:
1043
  lsl_outlet.push_sample(filtered_point[-1])
 
399
 
400
 
401
  class Capture:
402
+ def __init__(self, detector_cls=None, stimulator_cls=None):
403
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
404
  self.filename = EDF_PATH / 'recording.edf'
405
  self._p_capture = None
 
433
  self._t_capture = None
434
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
435
 
436
+ self.detector_cls = detector_cls
437
+ self.stimulator_cls = stimulator_cls
438
 
439
  # widgets ===============================
440
 
 
666
  self.b_duration.observe(self.on_b_duration, 'value')
667
  self.b_filter.observe(self.on_b_filter, 'value')
668
  self.b_detect.observe(self.on_b_detect, 'value')
669
+ self.b_stimulate.observe(self.on_b_stimulate, 'value')
670
  self.b_record.observe(self.on_b_record, 'value')
671
  self.b_lsl.observe(self.on_b_lsl, 'value')
672
  self.b_display.observe(self.on_b_display, 'value')
 
709
  self.b_filter.disabled = False
710
  self.b_detect.disabled = False
711
  self.b_record.disabled = False
712
+ self.b_lsl.disabled = False
713
  self.b_display.disabled = False
714
  self.b_clock.disabled = False
715
  self.b_radio_ch2.disabled = False
 
735
  self.b_filter.disabled = True
736
  self.b_stimulate.disabled = True
737
  self.b_filter.disabled = True
738
+ self.b_detect.disabled = True
739
  self.b_record.disabled = True
740
+ self.b_lsl.disabled = True
741
  self.b_display.disabled = True
742
  self.b_clock.disabled = True
743
  self.b_radio_ch2.disabled = True
 
787
  if self._t_capture is not None:
788
  warnings.warn("Capture already running, operation aborted.")
789
  return
790
+ detector_cls = self.detector_cls if self.detect else None
791
+ stimulator_cls = self.stimulator_cls if self.stimulate else None
792
  self._t_capture = Thread(target=self.start_capture,
793
+ args=(self.filter,
794
+ detector_cls,
795
+ self.threshold,
796
+ stimulator_cls,
797
+ self.record,
798
+ self.lsl,
799
+ self.display,
800
+ 500,
801
+ self.python_clock))
802
  self._t_capture.start()
803
  elif val == 'Stop':
804
  with self._lock_msg_out:
 
957
 
958
  def start_capture(self,
959
  filter,
960
+ detector_cls,
961
+ threshold,
962
+ stimulator_cls,
963
  record,
964
  lsl,
965
  viz,
 
985
  alpha_std=self.polyak_std,
986
  epsilon=self.epsilon)
987
 
988
+ detector = detector_cls(threshold) if detector_cls is not None else None
989
+ stimulator = stimulator_cls() if stimulator_cls is not None else None
990
 
991
  self._p_capture = mp.Process(target=_capture_process,
992
  args=(p_data_o,
 
998
  self.channel_states)
999
  )
1000
  self._p_capture.start()
1001
+ print(f"PID capture: {self._p_capture.pid}")
1002
 
1003
  if viz:
1004
  live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
 
1044
 
1045
  filtered_point = n_array.tolist()
1046
 
1047
+ if detector is not None:
1048
+ detection_signal = detector.detect(filtered_point)
1049
 
1050
+ if stimulator is not None:
1051
+ stimulator.stimulate(detection_signal)
 
 
 
1052
 
1053
  if lsl:
1054
  lsl_outlet.push_sample(filtered_point[-1])
portiloop/{inference.py → detection.py} RENAMED
@@ -1,19 +1,49 @@
1
- from pycoral.utils import edgetpu
2
- import time
3
  from abc import ABC, abstractmethod
 
4
  from pathlib import Path
 
 
5
  import numpy as np
6
 
7
- DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
8
- print(DEFAULT_MODEL_PATH)
9
 
10
- class AbstractQuantizedModelForInference(ABC):
 
 
 
 
 
 
 
 
 
 
 
 
11
  @abstractmethod
12
- def add_datapoints(self, input_float):
13
- return NotImplemented
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- class QuantizedModelForInference(AbstractQuantizedModelForInference):
16
- def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
17
  model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
18
  self.verbose = verbose
19
  self.channel = channel
@@ -32,60 +62,63 @@ class QuantizedModelForInference(AbstractQuantizedModelForInference):
32
  self.seq_stride = seq_stride
33
  self.window_size = window_size
34
 
35
- self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * i) for i in range(self.num_models_parallel)]
36
- for idx, i in enumerate(self.stride_counters[1:]):
37
- self.stride_counters[idx+1] = i - self.stride_counters[idx]
 
 
38
  self.current_stride_counter = self.stride_counters[0] - 1
39
 
40
-
41
- def add_datapoints(self, inputs_float):
 
42
  res = []
43
- for inp in inputs_float:
44
  result = self.add_datapoint(inp)
45
  if result is not None:
46
- res.append(result)
47
  return res
48
-
49
-
50
  def add_datapoint(self, input_float):
51
- input_float = input_float[self.channel-1]
52
  result = None
53
  self.buffer.append(input_float)
54
  if len(self.buffer) > self.window_size:
55
  self.buffer = self.buffer[1:]
56
  self.current_stride_counter += 1
57
- if self.current_stride_counter == self.stride_counter[self.interpreter_counter]:
58
  result = self.call_model(self.interpreter_counter, self.buffer)
59
  self.interpreter_counter += 1
60
- self.interpreter_counter %= self.num_model_parallel
61
  self.current_stride_counter = 0
62
  return result
63
-
64
-
65
-
66
  def call_model(self, idx, input_float=None):
67
  if input_float is None:
68
- # For debuggin purposes
69
- input_shape = input_details[0]['shape']
70
  input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
71
  else:
72
  # Convert float input to Int
73
- input_scale, input_zero_point = input_details[0]["quantization"]
74
  input = np.asarray(input_float) / input_scale + input_zero_point
75
- input = input.astype(input_details[0]["dtype"])
76
-
77
- interpreter.set_tensor(input_details[0]['index'], input)
78
- if self.verbose:
79
- start_time = time.time()
80
-
81
- interpreter.invoke()
82
-
83
- if self.verbose:
84
- end_time = time.time()
85
 
86
- output = interpreter.get_tensor(output_details[0]['index'])
87
- output_scale, output_zero_point = input_details[0]["quantization"]
88
- output = float(output - output_zero_point) * output_scale
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  if self.verbose:
91
  print(f"Computed output {output} in {end_time - start_time} seconds")
 
 
 
1
  from abc import ABC, abstractmethod
2
+ import time
3
  from pathlib import Path
4
+
5
+ from pycoral.utils import edgetpu
6
  import numpy as np
7
 
 
 
8
 
9
+ # Abstract interface for developers:
10
+
11
+ class Detector(ABC):
12
+
13
+ def __init__(self, threshold=None):
14
+ """
15
+ If implementing __init__() in your subclass, it must take threshold as a keyword argument.
16
+ This is the value of the threshold that the user can set in the Portiloop GUI.
17
+ Caution: even if you don't need this manual threshold in your application,
18
+ your implementation of __init__() still needs to have this keyword argument.
19
+ """
20
+ self.threshold = threshold
21
+
22
  @abstractmethod
23
+ def detect(self, datapoints):
24
+ """
25
+ Takes datapoints as input and outputs a detection signal.
26
+
27
+ Args:
28
+ datapoints: list of lists of n channels: may contain several datapoints.
29
+ A datapoint is a list of n floats, 1 for each channel.
30
+ In the current version of Portiloop, there is always only one datapoint per datapoints list.
31
+
32
+ Returns:
33
+ signal: Object: output detection signal (for instance, the output of a neural network);
34
+ this output signal is the input of the Stimulator.stimulate method.
35
+ If you don't mean to use a Stimulator, you can simply return None.
36
+ """
37
+ raise NotImplementedError
38
+
39
+
40
+ # Example implementation for sleep spindles:
41
+
42
+ DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
43
+ # print(DEFAULT_MODEL_PATH)
44
 
45
+ class SleepSpindleRealTimeDetector(Detector):
46
+ def __init__(self, threshold=0.5, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
47
  model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
48
  self.verbose = verbose
49
  self.channel = channel
 
62
  self.seq_stride = seq_stride
63
  self.window_size = window_size
64
 
65
+ self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * (i + 1)) for i in range(self.num_models_parallel)]
66
+ for idx in reversed(range(1, len(self.stride_counters))):
67
+ self.stride_counters[idx] -= self.stride_counters[idx-1]
68
+ assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
69
+
70
  self.current_stride_counter = self.stride_counters[0] - 1
71
 
72
+ super().__init__(threshold)
73
+
74
+ def detect(self, datapoints):
75
  res = []
76
+ for inp in datapoints:
77
  result = self.add_datapoint(inp)
78
  if result is not None:
79
+ res.append(result >= self.threshold)
80
  return res
81
+
 
82
  def add_datapoint(self, input_float):
83
+ input_float = input_float[self.channel - 1]
84
  result = None
85
  self.buffer.append(input_float)
86
  if len(self.buffer) > self.window_size:
87
  self.buffer = self.buffer[1:]
88
  self.current_stride_counter += 1
89
+ if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
90
  result = self.call_model(self.interpreter_counter, self.buffer)
91
  self.interpreter_counter += 1
92
+ self.interpreter_counter %= self.num_models_parallel
93
  self.current_stride_counter = 0
94
  return result
95
+
 
 
96
  def call_model(self, idx, input_float=None):
97
  if input_float is None:
98
+ # For debugging purposes
99
+ input_shape = self.input_details[0]['shape']
100
  input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
101
  else:
102
  # Convert float input to Int
103
+ input_scale, input_zero_point = self.input_details[0]["quantization"]
104
  input = np.asarray(input_float) / input_scale + input_zero_point
105
+ input = input.astype(self.input_details[0]["dtype"])
106
+ input = input.reshape((1, 1, -1))
 
 
 
 
 
 
 
 
107
 
108
+ # FIXME: bad sequence length: 50 instead of 1:
109
+ # self.interpreters[idx].set_tensor(self.input_details[0]['index'], input)
110
+ #
111
+ # if self.verbose:
112
+ # start_time = time.time()
113
+ #
114
+ # self.interpreters[idx].invoke()
115
+ #
116
+ # if self.verbose:
117
+ # end_time = time.time()
118
+ # output = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
119
+ # output_scale, output_zero_point = self.input_details[0]["quantization"]
120
+ # output = float(output - output_zero_point) * output_scale
121
+ output = np.random.uniform() # FIXME: remove
122
 
123
  if self.verbose:
124
  print(f"Computed output {output} in {end_time - start_time} seconds")
portiloop/notebooks/tests.ipynb CHANGED
@@ -2,47 +2,18 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
- "outputs": [
11
- {
12
- "data": {
13
- "application/vnd.jupyter.widget-view+json": {
14
- "model_id": "910f8e489b6341119f4d6e17a5b2aedc",
15
- "version_major": 2,
16
- "version_minor": 0
17
- },
18
- "text/plain": [
19
- "VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
20
- ]
21
- },
22
- "metadata": {},
23
- "output_type": "display_data"
24
- },
25
- {
26
- "name": "stderr",
27
- "output_type": "stream",
28
- "text": [
29
- "Process Process-1:\n",
30
- "Traceback (most recent call last):\n",
31
- " File \"/usr/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
32
- " self.run()\n",
33
- " File \"/usr/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
34
- " self._target(*self._args, **self._kwargs)\n",
35
- " File \"/home/mendel/software/portiloop-software/portiloop/capture.py\", line 325, in _capture_process\n",
36
- " assert data == [0x3E], \"The communication with the ADS cannot be established.\"\n",
37
- "AssertionError: The communication with the ADS cannot be established.\n"
38
- ]
39
- }
40
- ],
41
  "source": [
42
  "from portiloop.capture import Capture\n",
43
- "from portiloop.inference import QuantizedModelForInference\n",
 
44
  "\n",
45
- "cap = Capture(QuantizedModelForInference)"
46
  ]
47
  }
48
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "source": [
12
  "from portiloop.capture import Capture\n",
13
+ "from portiloop.detection import SleepSpindleRealTimeDetector\n",
14
+ "from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
15
  "\n",
16
+ "cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
17
  ]
18
  }
19
  ],
portiloop/stimulation.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import time
3
+
4
+
5
+ # Abstract interface for developers:
6
+
7
+ class Stimulator(ABC):
8
+
9
+ @abstractmethod
10
+ def stimulate(self, detection_signal):
11
+ """
12
+ Stimulates accordingly to the output of the Detector.
13
+
14
+ Args:
15
+ detection_signal: Object: the output of the Detector.add_datapoints method.
16
+ """
17
+ raise NotImplementedError
18
+
19
+
20
+ # Example implementation for sleep spindles:
21
+
22
+ class SleepSpindleRealTimeStimulator(Stimulator):
23
+ def __init__(self):
24
+ self.last_detected_ts = time.time()
25
+ self.wait_t = 0.4 # 400 ms
26
+
27
+ def stimulate(self, detection_signal):
28
+ for sig in detection_signal:
29
+ if sig:
30
+ ts = time.time()
31
+ if ts - self.last_detected_ts > self.wait_t:
32
+ print("stimulation")
33
+ else:
34
+ print("same spindle")
35
+ self.last_detected_ts = ts