Commit 35cdf83 · committed by ybouteiller · Parent(s): 120f728

    debugged and cleaned + implement stimulation skeletton

Files changed:
- portiloop/capture.py (+28 -17)
- portiloop/{inference.py → detection.py} (+73 -40)
- portiloop/notebooks/tests.ipynb (+5 -34)
- portiloop/stimulation.py (+35 -0)

portiloop/capture.py  CHANGED

@@ -399,7 +399,7 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
 
 
 class Capture:
-    def __init__(self,
+    def __init__(self, detector_cls=None, stimulator_cls=None):
         # {now.strftime('%m_%d_%Y_%H_%M_%S')}
         self.filename = EDF_PATH / 'recording.edf'
         self._p_capture = None
@@ -433,7 +433,8 @@ class Capture:
         self._t_capture = None
         self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
 
-        self.
+        self.detector_cls = detector_cls
+        self.stimulator_cls = stimulator_cls
 
         # widgets ===============================
 
@@ -665,6 +666,7 @@ class Capture:
         self.b_duration.observe(self.on_b_duration, 'value')
         self.b_filter.observe(self.on_b_filter, 'value')
         self.b_detect.observe(self.on_b_detect, 'value')
+        self.b_stimulate.observe(self.on_b_stimulate, 'value')
         self.b_record.observe(self.on_b_record, 'value')
         self.b_lsl.observe(self.on_b_lsl, 'value')
         self.b_display.observe(self.on_b_display, 'value')
@@ -707,7 +709,7 @@
         self.b_filter.disabled = False
         self.b_detect.disabled = False
         self.b_record.disabled = False
-        self.
+        self.b_lsl.disabled = False
         self.b_display.disabled = False
         self.b_clock.disabled = False
         self.b_radio_ch2.disabled = False
@@ -733,8 +735,9 @@
         self.b_filter.disabled = True
         self.b_stimulate.disabled = True
         self.b_filter.disabled = True
+        self.b_detect.disabled = True
         self.b_record.disabled = True
-        self.
+        self.b_lsl.disabled = True
         self.b_display.disabled = True
         self.b_clock.disabled = True
         self.b_radio_ch2.disabled = True
@@ -784,8 +787,18 @@
             if self._t_capture is not None:
                 warnings.warn("Capture already running, operation aborted.")
                 return
+            detector_cls = self.detector_cls if self.detect else None
+            stimulator_cls = self.stimulator_cls if self.stimulate else None
             self._t_capture = Thread(target=self.start_capture,
-                                     args=(self.filter,
+                                     args=(self.filter,
+                                           detector_cls,
+                                           self.threshold,
+                                           stimulator_cls,
+                                           self.record,
+                                           self.lsl,
+                                           self.display,
+                                           500,
+                                           self.python_clock))
             self._t_capture.start()
         elif val == 'Stop':
             with self._lock_msg_out:
@@ -944,8 +957,9 @@
 
     def start_capture(self,
                       filter,
-
-
+                      detector_cls,
+                      threshold,
+                      stimulator_cls,
                       record,
                       lsl,
                       viz,
@@ -971,8 +985,8 @@
                                        alpha_std=self.polyak_std,
                                        epsilon=self.epsilon)
 
-        if
-
+        detector = detector_cls(threshold) if detector_cls is not None else None
+        stimulator = stimulator_cls() if stimulator_cls is not None else None
 
         self._p_capture = mp.Process(target=_capture_process,
                                      args=(p_data_o,
@@ -984,7 +998,7 @@
                                            self.channel_states)
                                      )
         self._p_capture.start()
-
+        print(f"PID capture: {self._p_capture.pid}")
 
         if viz:
             live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
@@ -1030,14 +1044,11 @@
 
                 filtered_point = n_array.tolist()
 
-                if
-
+                if detector is not None:
+                    detection_signal = detector.detect(filtered_point)
 
-
-
-
-                if stimulate and True:
-                    print('stimulation')
+                if stimulator is not None:
+                    stimulator.stimulate(detection_signal)
 
                 if lsl:
                     lsl_outlet.push_sample(filtered_point[-1])

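The wiring added above is easier to read outside of the diff. The following is a minimal standalone sketch, not code from the repository, of how the new hooks are meant to chain: start_capture instantiates the detector with the GUI threshold and the stimulator with no arguments, then each filtered datapoint is passed to Detector.detect and the resulting detection signal is handed to Stimulator.stimulate. DummyDetector, DummyStimulator and fake_capture are hypothetical placeholders standing in for the real classes and for the capture process.

import random


class DummyDetector:
    def __init__(self, threshold):
        self.threshold = threshold  # the GUI threshold, as passed by start_capture

    def detect(self, datapoints):
        # one datapoint = one list of channel values; return one boolean per datapoint
        return [dp[-1] >= self.threshold for dp in datapoints]


class DummyStimulator:
    def stimulate(self, detection_signal):
        if any(detection_signal):
            print("stimulation")


def fake_capture(n=5):
    # stand-in for the filtered datapoints coming out of the capture process
    for _ in range(n):
        yield [[random.random() for _ in range(8)]]


detector_cls, stimulator_cls = DummyDetector, DummyStimulator
detector = detector_cls(0.5) if detector_cls is not None else None    # mirrors start_capture
stimulator = stimulator_cls() if stimulator_cls is not None else None

for filtered_point in fake_capture():
    if detector is not None:
        detection_signal = detector.detect(filtered_point)
        if stimulator is not None:
            stimulator.stimulate(detection_signal)

Keeping both classes optional means the capture loop skips detection and stimulation entirely when the corresponding GUI toggles are off, since the guards above short-circuit on None.
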
portiloop/{inference.py → detection.py}  RENAMED

@@ -1,19 +1,49 @@
-from pycoral.utils import edgetpu
-import time
 from abc import ABC, abstractmethod
+import time
 from pathlib import Path
+
+from pycoral.utils import edgetpu
 import numpy as np
 
-DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
-print(DEFAULT_MODEL_PATH)
 
-
+# Abstract interface for developers:
+
+class Detector(ABC):
+
+    def __init__(self, threshold=None):
+        """
+        If implementing __init__() in your subclass, it must take threshold as a keyword argument.
+        This is the value of the threshold that the user can set in the Portiloop GUI.
+        Caution: even if you don't need this manual threshold in your application,
+        your implementation of __init__() still needs to have this keyword argument.
+        """
+        self.threshold = threshold
+
     @abstractmethod
-    def
-
+    def detect(self, datapoints):
+        """
+        Takes datapoints as input and outputs a detection signal.
+
+        Args:
+            datapoints: list of lists of n channels: may contain several datapoints.
+                A datapoint is a list of n floats, 1 for each channel.
+                In the current version of Portiloop, there is always only one datapoint per datapoints list.
+
+        Returns:
+            signal: Object: output detection signal (for instance, the output of a neural network);
+                this output signal is the input of the Stimulator.stimulate method.
+                If you don't mean to use a Stimulator, you can simply return None.
+        """
+        raise NotImplementedError
+
+
+# Example implementation for sleep spindles:
+
+DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
+# print(DEFAULT_MODEL_PATH)
 
-class
-    def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
+class SleepSpindleRealTimeDetector(Detector):
+    def __init__(self, threshold=0.5, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
         model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
         self.verbose = verbose
         self.channel = channel
@@ -32,60 +62,63 @@ class QuantizedModelForInference(AbstractQuantizedModelForInference):
         self.seq_stride = seq_stride
         self.window_size = window_size
 
-        self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * i) for i in range(self.num_models_parallel)]
-        for idx
-            self.stride_counters[idx
+        self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * (i + 1)) for i in range(self.num_models_parallel)]
+        for idx in reversed(range(1, len(self.stride_counters))):
+            self.stride_counters[idx] -= self.stride_counters[idx-1]
+        assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
+
         self.current_stride_counter = self.stride_counters[0] - 1
 
-
-
+        super().__init__(threshold)
+
+    def detect(self, datapoints):
         res = []
-        for inp in
+        for inp in datapoints:
             result = self.add_datapoint(inp)
             if result is not None:
-                res.append(result)
+                res.append(result >= self.threshold)
         return res
-
-
+
     def add_datapoint(self, input_float):
-        input_float = input_float[self.channel-1]
+        input_float = input_float[self.channel - 1]
         result = None
         self.buffer.append(input_float)
         if len(self.buffer) > self.window_size:
             self.buffer = self.buffer[1:]
             self.current_stride_counter += 1
-            if self.current_stride_counter == self.
+            if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
                 result = self.call_model(self.interpreter_counter, self.buffer)
                 self.interpreter_counter += 1
-                self.interpreter_counter %= self.
+                self.interpreter_counter %= self.num_models_parallel
                 self.current_stride_counter = 0
         return result
-
-
-
+
     def call_model(self, idx, input_float=None):
         if input_float is None:
-            # For
-            input_shape = input_details[0]['shape']
+            # For debugging purposes
+            input_shape = self.input_details[0]['shape']
            input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
         else:
             # Convert float input to Int
-            input_scale, input_zero_point = input_details[0]["quantization"]
+            input_scale, input_zero_point = self.input_details[0]["quantization"]
             input = np.asarray(input_float) / input_scale + input_zero_point
-            input = input.astype(input_details[0]["dtype"])
-
-        interpreter.set_tensor(input_details[0]['index'], input)
-        if self.verbose:
-            start_time = time.time()
-
-        interpreter.invoke()
-
-        if self.verbose:
-            end_time = time.time()
+            input = input.astype(self.input_details[0]["dtype"])
+            input = input.reshape((1, 1, -1))
 
-
-
-
+        # FIXME: bad sequence length: 50 instead of 1:
+        # self.interpreters[idx].set_tensor(self.input_details[0]['index'], input)
+        #
+        # if self.verbose:
+        #     start_time = time.time()
+        #
+        # self.interpreters[idx].invoke()
+        #
+        # if self.verbose:
+        #     end_time = time.time()
+        # output = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
+        # output_scale, output_zero_point = self.input_details[0]["quantization"]
+        # output = float(output - output_zero_point) * output_scale
+        output = np.random.uniform()  # FIXME: remove
 
         if self.verbose:
             print(f"Computed output {output} in {end_time - start_time} seconds")

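Since the renamed module now exposes an abstract Detector interface, a custom detector only has to subclass it, accept threshold as a keyword argument, and implement detect(). The sketch below is a hypothetical example, not part of this commit: it simply thresholds the absolute amplitude of one channel and returns one boolean per datapoint, so it can be chained with a Stimulator in the same way as SleepSpindleRealTimeDetector.

from portiloop.detection import Detector


class AmplitudeDetector(Detector):
    def __init__(self, threshold=None, channel=2):
        # threshold must stay a keyword argument: it is the value set in the Portiloop GUI
        super().__init__(threshold)
        self.channel = channel

    def detect(self, datapoints):
        # datapoints: list of datapoints, each a list of n channel floats
        # returns one boolean per datapoint, usable as the Stimulator input
        return [abs(dp[self.channel - 1]) >= self.threshold for dp in datapoints]

For reference, with the defaults seq_stride=42 and num_models_parallel=8, the corrected stride_counters computation gives [5, 10, 15, 21, 26, 31, 36, 42] before the in-place differencing and [5, 5, 5, 6, 5, 5, 5, 6] after, which sums to 42 and therefore satisfies the new assertion.
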
portiloop/notebooks/tests.ipynb  CHANGED

@@ -2,47 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "id": "7b2fc5da",
    "metadata": {
     "scrolled": false
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "910f8e489b6341119f4d6e17a5b2aedc",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Process Process-1:\n",
-      "Traceback (most recent call last):\n",
-      "  File \"/usr/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
-      "    self.run()\n",
-      "  File \"/usr/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
-      "    self._target(*self._args, **self._kwargs)\n",
-      "  File \"/home/mendel/software/portiloop-software/portiloop/capture.py\", line 325, in _capture_process\n",
-      "    assert data == [0x3E], \"The communication with the ADS cannot be established.\"\n",
-      "AssertionError: The communication with the ADS cannot be established.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from portiloop.capture import Capture\n",
-    "from portiloop.
+    "from portiloop.detection import SleepSpindleRealTimeDetector\n",
+    "from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
     "\n",
-    "cap = Capture(
+    "cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
    ]
   }
  ],

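Both constructor arguments of Capture default to None, and start_capture only instantiates the classes that were provided (and enabled through the Detect/Stimulate toggles), so the notebook cell can also be run with detection alone. A possible variant, assuming the same environment:

from portiloop.capture import Capture
from portiloop.detection import SleepSpindleRealTimeDetector

# stimulator_cls keeps its default (None): start_capture will run detection without stimulation
cap = Capture(detector_cls=SleepSpindleRealTimeDetector)
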
portiloop/stimulation.py  ADDED

@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+import time
+
+
+# Abstract interface for developers:
+
+class Stimulator(ABC):
+
+    @abstractmethod
+    def stimulate(self, detection_signal):
+        """
+        Stimulates accordingly to the output of the Detector.
+
+        Args:
+            detection_signal: Object: the output of the Detector.add_datapoints method.
+        """
+        raise NotImplementedError
+
+
+# Example implementation for sleep spindles:
+
+class SleepSpindleRealTimeStimulator(Stimulator):
+    def __init__(self):
+        self.last_detected_ts = time.time()
+        self.wait_t = 0.4  # 400 ms
+
+    def stimulate(self, detection_signal):
+        for sig in detection_signal:
+            if sig:
+                ts = time.time()
+                if ts - self.last_detected_ts > self.wait_t:
+                    print("stimulation")
+                else:
+                    print("same spindle")
+                self.last_detected_ts = ts

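A custom Stimulator only has to implement stimulate(), which receives whatever object the Detector returned. The sketch below is a hypothetical example, not part of this commit: it keeps the same 400 ms refractory-window idea as SleepSpindleRealTimeStimulator but records stimulation timestamps instead of printing.

import time

from portiloop.stimulation import Stimulator


class LoggingStimulator(Stimulator):
    def __init__(self, wait_t=0.4):
        self.last_detected_ts = time.time()
        self.wait_t = wait_t  # refractory window in seconds
        self.log = []         # timestamps at which a stimulus would have been delivered

    def stimulate(self, detection_signal):
        # detection_signal: the object returned by the Detector (here: a list of booleans)
        for sig in detection_signal:
            if sig:
                ts = time.time()
                if ts - self.last_detected_ts > self.wait_t:
                    self.log.append(ts)  # trigger the actual stimulus here
                self.last_detected_ts = ts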