ybouteiller commited on
Commit
ba200d7
·
1 Parent(s): 1f23291

implemented sound stimulation + new quantized model

Browse files
portiloop/detection.py CHANGED
@@ -66,6 +66,8 @@ class SleepSpindleRealTimeDetector(Detector):
66
  for idx in reversed(range(1, len(self.stride_counters))):
67
  self.stride_counters[idx] -= self.stride_counters[idx-1]
68
  assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
 
 
69
 
70
  self.current_stride_counter = self.stride_counters[0] - 1
71
 
@@ -104,20 +106,22 @@ class SleepSpindleRealTimeDetector(Detector):
104
  input = np.asarray(input_float) / input_scale + input_zero_point
105
  input = input.astype(self.input_details[0]["dtype"])
106
  input = input.reshape((1, 1, -1))
107
-
108
- # FIXME: bad sequence length: 50 instead of 1:
109
- # self.interpreters[idx].set_tensor(self.input_details[0]['index'], input)
110
- #
111
- # if self.verbose:
112
- # start_time = time.time()
113
- #
114
- # self.interpreters[idx].invoke()
115
- #
116
- # if self.verbose:
117
- # end_time = time.time()
118
- # output = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
119
- # output_scale, output_zero_point = self.input_details[0]["quantization"]
120
- # output = float(output - output_zero_point) * output_scale
 
 
121
  output = np.random.uniform() # FIXME: remove
122
 
123
  if self.verbose:
 
66
  for idx in reversed(range(1, len(self.stride_counters))):
67
  self.stride_counters[idx] -= self.stride_counters[idx-1]
68
  assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
69
+
70
+ self.h = [np.zeros((1, 7), dtype=np.int8) for _ in range(self.num_models_parallel)]
71
 
72
  self.current_stride_counter = self.stride_counters[0] - 1
73
 
 
106
  input = np.asarray(input_float) / input_scale + input_zero_point
107
  input = input.astype(self.input_details[0]["dtype"])
108
  input = input.reshape((1, 1, -1))
109
+
110
+ # TODO: Milo please implement this:
111
+ # self.interpreters[idx].set_tensor(self.input_details[0]['index'], (self.h[idx], input))
112
+
113
+ # if self.verbose:
114
+ # start_time = time.time()
115
+
116
+ # self.interpreters[idx].invoke()
117
+
118
+ # if self.verbose:
119
+ # end_time = time.time()
120
+ # output, self.h[idx] = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
121
+ # output_scale, output_zero_point = self.input_details[0]["quantization"]
122
+ # output = float(output - output_zero_point) * output_scale
123
+
124
+ # TODO: remove this line:
125
  output = np.random.uniform() # FIXME: remove
126
 
127
  if self.verbose:
portiloop/models/portiloop_model_quant.tflite CHANGED
Binary files a/portiloop/models/portiloop_model_quant.tflite and b/portiloop/models/portiloop_model_quant.tflite differ
 
portiloop/notebooks/tests.ipynb CHANGED
@@ -15,6 +15,14 @@
15
  "\n",
16
  "cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
17
  ]
 
 
 
 
 
 
 
 
18
  }
19
  ],
20
  "metadata": {
 
15
  "\n",
16
  "cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
17
  ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "9bd24d7a",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": []
26
  }
27
  ],
28
  "metadata": {
portiloop/sounds/sample1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64c7aaf93d583adaae3b6942a8192709a7c342feb3db6a932f1e278a760c6037
3
+ size 889422
portiloop/sounds/stimulus.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8730f1af623d95eacaa32a50ec117a36871cbb26960db62fe5d71b7ab96ba0a8
3
+ size 454576
portiloop/stimulation.py CHANGED
@@ -1,5 +1,8 @@
1
  from abc import ABC, abstractmethod
2
  import time
 
 
 
3
 
4
 
5
  # Abstract interface for developers:
@@ -17,10 +20,14 @@ class Stimulator(ABC):
17
  raise NotImplementedError
18
 
19
 
20
- # Example implementation for sleep spindles:
21
 
22
  class SleepSpindleRealTimeStimulator(Stimulator):
23
  def __init__(self):
 
 
 
 
24
  self.last_detected_ts = time.time()
25
  self.wait_t = 0.4 # 400 ms
26
 
@@ -29,7 +36,13 @@ class SleepSpindleRealTimeStimulator(Stimulator):
29
  if sig:
30
  ts = time.time()
31
  if ts - self.last_detected_ts > self.wait_t:
32
- print("stimulation")
33
- else:
34
- print("same spindle")
 
35
  self.last_detected_ts = ts
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
  import time
3
+ from playsound import playsound
4
+ from threading import Thread, Lock
5
+ from pathlib import Path
6
 
7
 
8
  # Abstract interface for developers:
 
20
  raise NotImplementedError
21
 
22
 
23
+ # Example implementation for sleep spindles
24
 
25
  class SleepSpindleRealTimeStimulator(Stimulator):
26
  def __init__(self):
27
+ self._sound = Path(__file__).parent / 'sounds' / 'stimulus.wav'
28
+ print(f"DEBUG:{self._sound}")
29
+ self._thread = None
30
+ self._lock = Lock()
31
  self.last_detected_ts = time.time()
32
  self.wait_t = 0.4 # 400 ms
33
 
 
36
  if sig:
37
  ts = time.time()
38
  if ts - self.last_detected_ts > self.wait_t:
39
+ with self._lock:
40
+ if self._thread is None:
41
+ self._thread = Thread(target=self._t_sound, daemon=True)
42
+ self._thread.start()
43
  self.last_detected_ts = ts
44
+
45
+ def _t_sound(self):
46
+ playsound(self._sound)
47
+ with self._lock:
48
+ self._thread = None