File size: 6,633 Bytes
35cdf83
 
 
fcba0a9
35cdf83
fcba0a9
 
7d40d1a
 
35cdf83
 
 
 
 
 
 
35af352
35cdf83
35af352
35cdf83
 
35af352
35cdf83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a64c5cc
35cdf83
 
 
35af352
 
 
 
 
 
 
 
35cdf83
 
 
 
 
 
7d40d1a
 
 
 
35cdf83
 
 
 
 
 
 
 
 
 
 
 
 
 
ba200d7
 
35cdf83
 
 
35af352
35cdf83
 
7d40d1a
 
 
35af352
 
 
7d40d1a
35cdf83
 
 
 
 
 
 
 
45a88e4
 
 
35cdf83
 
45a88e4
35cdf83
 
45a88e4
35cdf83
 
 
45a88e4
 
35cdf83
 
 
 
ba200d7
45a88e4
 
 
ba200d7
45a88e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba200d7
45a88e4
 
 
 
ba200d7
45a88e4
 
 
 
 
 
 
35cdf83
45a88e4
da5fc70
45a88e4
35cdf83
45a88e4
35cdf83
45a88e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from abc import ABC, abstractmethod
import time
from pathlib import Path
from portiloop.src import ADS

if ADS:
    from pycoral.utils import edgetpu
else:
    import tensorflow as tf
import numpy as np


# Abstract interface for developers:

class Detector(ABC):
    
    def __init__(self, threshold=None, channel=None):
        """
        Mandatory arguments are from the in the Portiloop GUI.
        """
        self.threshold = threshold
        self.channel = channel

    @abstractmethod
    def detect(self, datapoints):
        """
        Takes datapoints as input and outputs a detection signal.
        
        Args:
            datapoints: list of lists of n channels: may contain several datapoints.
                A datapoint is a list of n floats, 1 for each channel.
                In the current version of Portiloop, there is always only one datapoint per datapoints list.
        
        Returns:
            signal: Object: output detection signal (for instance, the output of a neural network);
                this output signal is the input of the Stimulator.stimulate method.
                If you don't mean to use a Stimulator, you can simply return None.
        """
        raise NotImplementedError


# Example implementation for sleep spindles:
        
DEFAULT_MODEL_PATH = str(Path(__file__).parent.parent / "models/portiloop_model_quant.tflite")
# print(DEFAULT_MODEL_PATH)

class SleepSpindleRealTimeDetector(Detector):
    def __init__(self,
                 threshold=0.5,
                 num_models_parallel=8,
                 window_size=54,
                 seq_stride=42,
                 model_path=None,
                 verbose=False,
                 channel=2):
        model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
        self.verbose = verbose
        self.num_models_parallel = num_models_parallel
        
        self.interpreters = []
        for i in range(self.num_models_parallel):
            if ADS:
                self.interpreters.append(edgetpu.make_interpreter(model_path))
            else:
                self.interpreters.append(tf.lite.Interpreter(model_path=model_path))
            self.interpreters[i].allocate_tensors()
        self.interpreter_counter = 0
        
        self.input_details = self.interpreters[0].get_input_details()
        self.output_details = self.interpreters[0].get_output_details()
        
        self.buffer = []
        self.seq_stride = seq_stride
        self.window_size = window_size 
        
        self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * (i + 1)) for i in range(self.num_models_parallel)]
        for idx in reversed(range(1, len(self.stride_counters))):
            self.stride_counters[idx] -= self.stride_counters[idx-1]
        assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
        
        self.h = [np.zeros((1, 7), dtype=np.int8) for _ in range(self.num_models_parallel)]
            
        self.current_stride_counter = self.stride_counters[0] - 1
        
        super().__init__(threshold, channel)

    def detect(self, datapoints):
        """
        Takes datapoints as input and outputs a detection signal.
        datapoints is a list of lists of n channels: may contain several datapoints.
        
        The output signal is a list of tuples (is_spindle, is_train_of_spindles).
        
        """
        res = []
        for inp in datapoints:
            result = self.add_datapoint(inp)
            if result is not None:
                res.append(result >= self.threshold)
        return res

    def add_datapoint(self, input_float):
        '''
        Add one datapoint to the buffer
        '''
        input_float = input_float[self.channel - 1]
        result = None
        # Add to current buffer
        self.buffer.append(input_float)
        if len(self.buffer) > self.window_size:
            # Remove the end of the buffer
            self.buffer = self.buffer[1:]
            self.current_stride_counter += 1
            if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
                # If we have reached the next window size, we send the current buffer to the inference function and update the hidden state
                result, self.h[self.interpreter_counter] = self.forward_tflite(self.interpreter_counter, self.buffer, self.h[self.interpreter_counter])
                self.interpreter_counter += 1
                self.interpreter_counter %= self.num_models_parallel
                self.current_stride_counter = 0
        return result
        
    def forward_tflite(self, idx, input_x, input_h):
        input_details = self.interpreters[idx].get_input_details()
        output_details = self.interpreters[idx].get_output_details()
        
        # convert input to int 
        input_scale, input_zero_point = input_details[1]["quantization"]
        input_x = np.asarray(input_x) / input_scale + input_zero_point
        input_data_x = input_x.astype(input_details[1]["dtype"])
        input_data_x = np.expand_dims(input_data_x, (0, 1))

        # input_scale, input_zero_point = input_details[0]["quantization"]
        # input = np.asarray(input) / input_scale + input_zero_point

        # Test the model on random input data.
        input_shape_h = input_details[0]['shape']
        input_shape_x = input_details[1]['shape']

        # input_data_h = np.array(np.random.random_sample(input_shape_h), dtype=np.int8)
        # input_data_x = np.array(np.random.random_sample(input_shape_x), dtype=np.int8)
        self.interpreters[idx].set_tensor(input_details[0]['index'], input_h)
        self.interpreters[idx].set_tensor(input_details[1]['index'], input_data_x)
        
        if self.verbose:
            start_time = time.time()

        self.interpreters[idx].invoke()
        
        if self.verbose:
            end_time = time.time()

        # The function `get_tensor()` returns a copy of the tensor data.
        # Use `tensor()` in order to get a pointer to the tensor.
        output_data_h = self.interpreters[idx].get_tensor(output_details[0]['index'])
        output_data_y = self.interpreters[idx].get_tensor(output_details[1]['index'])

        output_scale, output_zero_point = output_details[1]["quantization"]
        output_data_y = (int(output_data_y) - output_zero_point) * output_scale
        
        if self.verbose:
            print(f"Computed output {output_data_y} in {end_time - start_time} seconds")

        return output_data_y, output_data_h