Spaces:
Build error
Build error
File size: 5,236 Bytes
35cdf83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from abc import ABC, abstractmethod
import time
from pathlib import Path
from pycoral.utils import edgetpu
import numpy as np
# Abstract interface for developers:
class Detector(ABC):
    """Abstract interface that developers subclass to implement a detector."""

    def __init__(self, threshold=None):
        """
        Store the manual threshold selected by the user in the Portiloop GUI.

        A subclass that defines its own __init__() must still accept threshold
        as a keyword argument, even when the application does not use it.

        Args:
            threshold: value of the threshold set by the user (None by default).
        """
        self.threshold = threshold

    @abstractmethod
    def detect(self, datapoints):
        """
        Turn incoming datapoints into a detection signal.

        Args:
            datapoints: list of datapoints; each datapoint is a list of n floats,
                one per channel. In the current version of Portiloop the list
                always holds exactly one datapoint.

        Returns:
            signal: Object: the detection signal (for instance, the output of a
                neural network). It is passed as input to Stimulator.stimulate;
                return None if no Stimulator is meant to be used.
        """
        raise NotImplementedError
# Example implementation for sleep spindles:
# Default model file: models/portiloop_model_quant.tflite, resolved relative to
# the directory that contains this module.
DEFAULT_MODEL_PATH = str(Path(__file__).parent.joinpath("models", "portiloop_model_quant.tflite"))
# print(DEFAULT_MODEL_PATH)
class SleepSpindleRealTimeDetector(Detector):
    """
    Example detector for sleep spindles backed by a pool of Edge TPU interpreters.

    Samples from one channel are accumulated in a sliding buffer of at most
    window_size values. The interpreters are used in round-robin order, and their
    calls are spread over every period of seq_stride datapoints.

    NOTE: the real inference is currently disabled in call_model() (see the FIXME
    there), so the value compared against the threshold is a random placeholder.
    """

    def __init__(
        self,
        threshold=0.5,
        num_models_parallel=8,
        window_size=54,
        seq_stride=42,
        model_path=None,
        verbose=False,
        channel=2,
    ):
        """
        Load the interpreters and set up the round-robin call schedule.

        Args:
            threshold: model outputs >= threshold are reported as True by detect().
            num_models_parallel: number of interpreters loaded from model_path.
            window_size: maximum number of samples kept in the sliding buffer.
            seq_stride: number of datapoints over which the num_models_parallel
                model calls are distributed.
            model_path: path of the model file; DEFAULT_MODEL_PATH when None.
            verbose: when True, call_model() prints each output and its duration.
            channel: 1-based index of the channel read from every datapoint.

        Raises:
            ValueError: if num_models_parallel is not in [1, seq_stride].
        """
        super().__init__(threshold)

        # With fewer than one interpreter there is nothing to call, and with more
        # interpreters than datapoints per stride some schedule slots would have a
        # length of 0, which add_datapoint() can never reach: detection would stall.
        if not 1 <= num_models_parallel <= seq_stride:
            raise ValueError(
                f"num_models_parallel must be between 1 and seq_stride ({seq_stride}), "
                f"got {num_models_parallel}"
            )

        model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
        self.verbose = verbose
        self.channel = channel
        self.num_models_parallel = num_models_parallel

        self.interpreters = []
        for _ in range(self.num_models_parallel):
            interpreter = edgetpu.make_interpreter(model_path)
            interpreter.allocate_tensors()
            self.interpreters.append(interpreter)
        self.interpreter_counter = 0

        # All interpreters load the same model, so the first one describes them all.
        self.input_details = self.interpreters[0].get_input_details()
        self.output_details = self.interpreters[0].get_output_details()

        self.buffer = []
        self.seq_stride = seq_stride
        self.window_size = window_size

        # Cumulative positions of the model calls inside one stride period...
        step = self.seq_stride / self.num_models_parallel
        boundaries = [np.floor(step * (i + 1)) for i in range(self.num_models_parallel)]
        # ...converted to the number of datapoints to wait before each call.
        self.stride_counters = [boundaries[0]] + [
            later - earlier for earlier, later in zip(boundaries, boundaries[1:])
        ]
        assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
        self.current_stride_counter = self.stride_counters[0] - 1

    def detect(self, datapoints):
        """
        Feed datapoints to the detector and threshold every model output produced.

        Args:
            datapoints: list of datapoints, each a list with one float per channel.

        Returns:
            list with one boolean (output >= threshold) per model call triggered by
            these datapoints; empty when no call was due.
        """
        res = []
        for datapoint in datapoints:
            result = self.add_datapoint(datapoint)
            if result is not None:
                res.append(result >= self.threshold)
        return res

    def add_datapoint(self, input_float):
        """
        Push one datapoint into the sliding buffer and run a model if one is due.

        Args:
            input_float: one datapoint (a sequence with one float per channel).

        Returns:
            the output of call_model() when the current interpreter is scheduled
            for this datapoint, None otherwise.
        """
        # self.channel is 1-based.
        sample = input_float[self.channel - 1]
        self.buffer.append(sample)
        if len(self.buffer) > self.window_size:
            self.buffer = self.buffer[1:]

        self.current_stride_counter += 1
        if self.current_stride_counter != self.stride_counters[self.interpreter_counter]:
            return None

        # NOTE(review): the buffer may still hold fewer than window_size samples
        # here during warm-up -- confirm the model tolerates that once the real
        # inference is re-enabled.
        result = self.call_model(self.interpreter_counter, self.buffer)
        self.interpreter_counter = (self.interpreter_counter + 1) % self.num_models_parallel
        self.current_stride_counter = 0
        return result

    def call_model(self, idx, input_float=None):
        """
        Quantize the input and compute the output of interpreter number idx.

        Args:
            idx: index of the interpreter in self.interpreters.
            input_float: sequence of float samples; when None, a dummy input with
                the model's input shape is generated (for debugging purposes).

        Returns:
            the model output. Currently a random float from np.random.uniform(),
            because the real inference is disabled (see FIXME below).
        """
        if input_float is None:
            # For debugging purposes.
            # NOTE: random_sample() draws from [0, 1), so the int8 cast yields zeros.
            input_shape = self.input_details[0]['shape']
            model_input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
        else:
            # Convert the float input to the quantized integer representation.
            input_scale, input_zero_point = self.input_details[0]["quantization"]
            model_input = np.asarray(input_float) / input_scale + input_zero_point
            model_input = model_input.astype(self.input_details[0]["dtype"])
        model_input = model_input.reshape((1, 1, -1))

        # Always take both timestamps so that the verbose message below never
        # refers to undefined names.
        start_time = time.perf_counter()

        # FIXME: bad sequence length: 50 instead of 1:
        # self.interpreters[idx].set_tensor(self.input_details[0]['index'], model_input)
        # self.interpreters[idx].invoke()
        # output = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
        # NOTE: the output must be dequantized with the *output* tensor parameters:
        # output_scale, output_zero_point = self.output_details[0]["quantization"]
        # output = float(output - output_zero_point) * output_scale
        output = np.random.uniform()  # FIXME: remove

        end_time = time.perf_counter()

        if self.verbose:
            print(f"Computed output {output} in {end_time - start_time} seconds")
        return output
|