Milo Sobral commited on
Commit
e614a99
·
1 Parent(s): 880213e

added nice library to capture, save and display results

Browse files
Files changed (2) hide show
  1. src/capture.py +335 -2
  2. src/notebooks/tests.ipynb +1 -3
src/capture.py CHANGED
@@ -1,2 +1,335 @@
1
- # Launch a capture of the electrodes
2
- # TODO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import threading
3
+ import time
4
+ import logging
5
+ import random
6
+ import Queue
7
+
8
+ from frontend import Frontend
9
+ from leds import LEDs, Color
10
+ from portilooplot.jupyter_plot import ProgressPlot
11
+
12
+ import ctypes
13
+ import numpy as np
14
+ import json
15
+
16
+
17
+ class Datapoint:
18
+ '''
19
+ Class to represent a single reading
20
+ '''
21
+ def __init__(self, raw_datapoint, temperature=[], num_channels=8):
22
+ # Initialize necessary data structures
23
+ self.num_channels = num_channels
24
+ self.reading = np.array(num_channels, dtype=float)
25
+
26
+ assert len(temperature) <= len(raw_datapoint), "Temperature array length must be lesser or equal to number of channels"
27
+ self.temperature = temperature
28
+
29
+ self._filter_datapoint(raw_datapoint, Datapoint.filter_2scomplement)
30
+
31
+ def _filter_datapoint(self, raw_datapoint, filter):
32
+ # Filter one datapoint with the given filter
33
+ assert len(raw_datapoint) == self.num_channels, "Datapoint dimensions do not match channel number"
34
+ for idx, point in enumerate(raw_datapoint):
35
+ # If the given index is a temperature, add that filter to get correct reading
36
+ if idx in self.temperature:
37
+ filter = lambda x : Datapoint.filter_temp(filter(x))
38
+ self.reading[idx] = filter(point)
39
+
40
+ def get_datapoint(self):
41
+ '''
42
+ Get readings of all channels in numpy array format
43
+ '''
44
+ return self.reading
45
+
46
+ def get_channel(self, channel_idx):
47
+ '''
48
+ Reading at the channel specified by channel_idx (0-7)
49
+ Returns a tuple (value(float), temperature(boolean)) --> temperature is True if channel is a temperature
50
+ '''
51
+ assert 0 <= channel_idx < self.num_channels - 1, "Channel index must be in range [0 - channel_num-1]"
52
+ return self.reading[channel_idx], (channel_idx in self.temperature)
53
+
54
+ def get_portilooplot(self):
55
+ '''
56
+ Returns the portilooplot ready version of the Datapoint
57
+ '''
58
+ return [[point] for point in self.reading]
59
+
60
+ @staticmethod
61
+ def filter_2scomplement(value):
62
+ '''
63
+ Convert from binary two's complement to binary int
64
+ '''
65
+ if (value & (1 << 23)) != 0:
66
+ value = value - (1 << 24)
67
+ return Datapoint.filter_23(value)
68
+
69
+ @staticmethod
70
+ def filter_23(value):
71
+ '''
72
+ Convert from binary int to normal int
73
+ '''
74
+ return (value * 4.5) / (2**23 - 1) # 23 because 1 bit is lost for sign
75
+
76
+ @staticmethod
77
+ def filter_temp(value):
78
+ '''
79
+ Convert from voltage reading to temperature reading in Celcius
80
+ '''
81
+ return int((value * 1000000.0 - 145300.0) / 490.0 + 25.0)
82
+
83
+
84
+ class CaptureThread(threading.Thread):
85
+ '''
86
+ Producer thread which reads from the EEG device. Thread does not process the data
87
+ '''
88
+
89
+ def __init__(self, q, freq=250, timeout=None, target=None, name=None):
90
+ super(CaptureThread, self).__init__()
91
+ self.timeout = timeout
92
+ self.target = target
93
+ self.name = name
94
+ self.q = q
95
+ self.freq = freq
96
+ self.frontend = Frontend()
97
+ self.leds = LEDs()
98
+
99
+ def run(self):
100
+ '''
101
+ Run the data capture continuously or until timeout
102
+ '''
103
+ self.init_checks()
104
+ start_time = time.time()
105
+ prev_ts = time.time()
106
+ ts_len = 1 / self.freq
107
+
108
+ while True:
109
+ if not self.q.full():
110
+ # Wait for frontend and for minimum time limit
111
+ while not self.frontend.is_ready() and not time.time() - prev_ts >= ts_len:
112
+ pass
113
+ prev_ts = time.time()
114
+
115
+ # Read values and add to q
116
+ values = self.frontend.read()
117
+ self.q.put(values)
118
+
119
+ # Wait until reading is fully ompleted
120
+ while self.frontend.is_ready():
121
+ pass
122
+
123
+ # Check for timeout
124
+ if time.time() - start_time > self.timeout:
125
+ break
126
+ return
127
+
128
+ def init_checks(self):
129
+ '''
130
+ Run Initial threads to the registers to make sure we can start reading
131
+ '''
132
+ data = self.frontend.read_regs(0x00, len(FRONTEND_CONFIG))
133
+ assert data == FRONTEND_CONFIG, f"Wrong config: {data} vs {FRONTEND_CONFIG}"
134
+ self.frontend.start()
135
+ print("EEG Frontend configured")
136
+ self.leds.led2(Color.PURPLE)
137
+ while not self.frontend.is_ready():
138
+ pass
139
+ print("Ready for data")
140
+
141
+
142
+ class FilterThread(threading.Thread):
143
+ def __init__(self, q, target=None, name=None, temperature=[], num_channels=8):
144
+ '''
145
+ Consume raw datapoint from the Capture points, filter them, put resulting datapoint objects into all queues in list
146
+ '''
147
+ super(FilterThread, self).__init__()
148
+ self.target = target
149
+ self.name = name
150
+
151
+ # Initialize thread safe datastructures for both consuming and producing
152
+ self.raw_q = q
153
+ self.qs = []
154
+
155
+ # Initilialize settings for the filtering
156
+ self.temperature = temperature
157
+ self.num_channels = num_channels
158
+
159
+ return
160
+
161
+ def run(self):
162
+ while True:
163
+ raw_data = None
164
+ # Get an item from CaptureThread
165
+ if not self.raw_q.empty():
166
+ raw_data = self.raw_q.get()
167
+ assert raw_data is not None, "Got a None item from CaptureThread in FilterThread"
168
+
169
+ datapoint = Datapoint(raw_data, )
170
+
171
+ # Put Item to all ConsumerThreads
172
+ for q in self.qs:
173
+ if not q.full():
174
+ q.put(item)
175
+ return
176
+
177
+ def add_q(self, q):
178
+ '''
179
+ Add a Queue to the list of queues where filtered values get added
180
+ '''
181
+ self.qs.append(q)
182
+
183
+ def remove_q(self, q):
184
+ '''
185
+ Remove a queue from the list
186
+ '''
187
+ self.qs.remove(q)
188
+
189
+ def update_settings(self, temperature=None, num_channels=None):
190
+ '''
191
+ Update Settings on the go
192
+ '''
193
+ if self.temperatures is not None:
194
+ self.temperature = temperature
195
+ if num_channels is not None:
196
+ self.num_channels = num_channels
197
+
198
+
199
+ class ConsumerThread(threading.Thread):
200
+ def __init__(self, q, target=None, name=None):
201
+ '''
202
+ Implemetns basic consumer logic, needs _consume_item() to be implemented
203
+ '''
204
+ super(ConsumerThread,self).__init__()
205
+ self.target = target
206
+ self.name = name
207
+ self.q = q
208
+
209
+ def run(self):
210
+ try:
211
+ while True:
212
+ item = None
213
+ if not self.q.empty():
214
+ item = self.q.get()
215
+
216
+ assert item is not None, "Got a None value from FilterThread in ConsumerThread"
217
+ self._consume_item(item)
218
+ except Exception:
219
+ self._on_exit()
220
+
221
+ def get_id(self):
222
+
223
+ # returns id of the respective thread
224
+ if hasattr(self, '_thread_id'):
225
+ return self._thread_id
226
+ for id, thread in threading._active.items():
227
+ if thread is self:
228
+ return id
229
+
230
+ def raise_exception(self):
231
+ thread_id = self.get_id()
232
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id,
233
+ ctypes.py_object(SystemExit))
234
+ if res > 1:
235
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, 0)
236
+ print('Exception raise failure')
237
+
238
+ @abstractmethod
239
+ def _consume_item(self, item):
240
+ raise NotImplementedError("_consume_item needs to be implemented in Subclass of ConsumerThread")
241
+
242
+ @abstractmethod
243
+ def _on_exit(self):
244
+ raise NotImplementedError("_on_exit needs to be implemented in Subclass of ConsumerThread")
245
+
246
+
247
+ class DisplayThread(ConsumerThread):
248
+ def __init__(self, q, max_window_len=100, num_channel=8, target=None, name=None):
249
+ super().__init__(q, target, name)
250
+ self.pp = ProgressPlot(plot_names=[f"channel#{i+1}" for i in range(num_channel)], max_window_len=max_window_len)
251
+
252
+ def _consume_item(self, item):
253
+ self.pp.update(item.get_portilooplot())
254
+
255
+ def _on_exit(self):
256
+ self.pp.finalize()
257
+
258
+
259
+ class SaveThread(ConsumerThread):
260
+ def __init__(self, q, default_loc='', target=None, name=None):
261
+ super().__init__(q, target, name)
262
+ self.save = []
263
+ self.default_loc = default_loc
264
+
265
+ def to_disk(self, destination):
266
+ print('Saving Method is not yet implemented')
267
+ pass
268
+
269
+ def _consume_item(self, item):
270
+ self.save.append(item.get_datapoint().to_list())
271
+
272
+ def _on_exit(self):
273
+ self.to_disk(self.default_loc)
274
+
275
+
276
+ class Capture:
277
+ def __init__(self, viz=True, record=True):
278
+ self.viz = viz
279
+ self.record = record
280
+
281
+ # Initialize data structures for capture and filtering
282
+ raw_q = Queue.Queue()
283
+ self.capture_thread = CaptureThread(raw_q)
284
+ self.filter_thread = FilterThread(raw_q)
285
+
286
+ # Declare data structures for viz and record functionality
287
+ self.viz_q = None
288
+ self.record_q = None
289
+ self.viz_thread = None
290
+ self.record_thread = None
291
+
292
+ self.capture_thread.start()
293
+ self.filter_thread.start()
294
+
295
+ if viz:
296
+ self.start_viz()
297
+
298
+ if record:
299
+ self.start_record()
300
+
301
+ def start_viz(self):
302
+ self.viz_q = Queue.Queue()
303
+ self.viz_thread = DisplayThread(self.viz_q)
304
+ self.filter_thread.add_q(self.viz_q)
305
+ self.viz_thread.start()
306
+
307
+ def stop_viz(self):
308
+ self.filter_thread.remove_q(self.viz_q)
309
+ self.viz_q = None
310
+ self.viz_thread.raise_exception()
311
+
312
+ def start_record(self):
313
+ self.record_q = Queue.Queue()
314
+ self.record_thread = SaveThread(self.record_q)
315
+ self.filter_thread.add_q(self.record_q)
316
+ self.record_thread.start()
317
+
318
+ def stop_record(self):
319
+ self.filter_thread.remove_q(self.viz_q)
320
+ self.viz_q = None
321
+ self.record_thread.raise_exception()
322
+
323
+ def save(self, destination=None):
324
+ if destination is not None:
325
+ self.record_thread.save(destination)
326
+ else:
327
+ self.record_thread.save()
328
+
329
+
330
+
331
+
332
+
333
+ if __name__ == "__main__":
334
+ # TODO: Argparse this
335
+ pass
src/notebooks/tests.ipynb CHANGED
@@ -10066,9 +10066,7 @@
10066
  },
10067
  {
10068
  "data": {
10069
- "application/javascript": [
10070
- "window.appendLearningCurve([{\"x\": 7047.0, \"y\": {\"channel#1\": {\"line-1\": 2.493805050111419}, \"channel#2\": {\"line-1\": 0.04119283451948577}, \"channel#3\": {\"line-1\": 0.04224318769492957}, \"channel#4\": {\"line-1\": 0.040438597254585894}, \"channel#5\": {\"line-1\": 2.4902720439758355}, \"channel#6\": {\"line-1\": 2.4909066547044105}, \"channel#7\": {\"line-1\": -0.0023152830976585267}, \"channel#8\": {\"line-1\": 53.0}}}]);"
10071
- ],
10072
  "text/plain": [
10073
  "<IPython.core.display.Javascript object>"
10074
  ]
 
10066
  },
10067
  {
10068
  "data": {
10069
+ "application/javascript": "window.appendLearningCurve([{\"x\": 7047.0, \"y\": {\"channel#1\": {\"line-1\": 2.493805050111419}, \"channel#2\": {\"line-1\": 0.04119283451948577}, \"channel#3\": {\"line-1\": 0.04224318769492957}, \"channel#4\": {\"line-1\": 0.040438597254585894}, \"channel#5\": {\"line-1\": 2.4902720439758355}, \"channel#6\": {\"line-1\": 2.4909066547044105}, \"channel#7\": {\"line-1\": -0.0023152830976585267}, \"channel#8\": {\"line-1\": 53.0}}}]);",
 
 
10070
  "text/plain": [
10071
  "<IPython.core.display.Javascript object>"
10072
  ]