File size: 4,822 Bytes
2cb7306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
import pyxdf
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012
from scipy.signal import butter, filtfilt, iirnotch, detrend
import time
from portiloop.src.stimulation import Stimulator


# Names of the XDF streams that xdf2array() searches for, keyed by the role
# each stream plays in this module.
STREAM_NAMES = dict(
    filtered_data='Portiloop Filtered',
    raw_data='Portiloop Raw Data',
    stimuli='Portiloop_stimuli',
)


class OfflineSleepSpindleRealTimeStimulator(Stimulator):
    """Stimulator that reports whether a batch of detections warrants a stimulation.

    A truthy detection counts as a stimulation only when more than ``wait_t``
    seconds of wall-clock time (``time.time()``) have elapsed since the
    previous truthy detection.
    """

    def __init__(self):
        # No delayer is attached until add_delayer() is called.
        self.delayer = None
        # Minimum gap between the previous detection and a new stimulation (400 ms).
        self.wait_t = 0.4
        # Wall-clock time of the latest detection; initialised to creation time.
        self.last_detected_ts = time.time()

    def stimulate(self, detection_signal):
        """Process a batch of detection flags.

        Args:
            detection_signal: Iterable of truthy/falsy detection flags.

        Returns:
            True if at least one truthy flag arrived more than ``wait_t``
            seconds after the previously recorded detection, False otherwise.
        """
        triggered = False
        for detected in detection_signal:
            if not detected:
                continue

            now = time.time()
            if now - self.last_detected_ts > self.wait_t:
                # Enough time has passed: let the delayer know, if there is one.
                if self.delayer is not None:
                    self.delayer.detected()
                triggered = True

            # The reference time is refreshed on every detection, whether or
            # not it resulted in a stimulation.
            self.last_detected_ts = now
        return triggered

    def add_delayer(self, delayer):
        """Attach ``delayer`` and replace its ``stimulate`` attribute with a callable returning True."""
        self.delayer = delayer
        delayer.stimulate = lambda: True

def xdf2array(xdf_path, channel):
    """Load a Portiloop XDF recording into a 2-D array with one row per sample.

    Only the samples common to both signal streams are kept: the number of
    rows is the smaller of the two ``sample_count`` values found in the
    stream footers.

    Args:
        xdf_path: Path of the XDF file, passed unchanged to ``pyxdf.load_xdf``.
        channel: 1-based channel number; column ``channel - 1`` of each
            stream's ``time_series`` is extracted.

    Returns:
        A ``(data, columns)`` tuple. ``data`` is a float array whose columns
        are, in order: the filtered stream's time stamp, the filtered sample,
        the raw sample and, only when a stimuli stream is present, a flag set
        to 1 on the row closest in time to each marker (0 elsewhere).
        ``columns`` is the list of matching column names.

    Raises:
        ValueError: If the filtered stream or the raw stream is missing.
    """
    xdf_data, _ = pyxdf.load_xdf(xdf_path)

    # Load all streams given their names
    filtered_stream, raw_stream, markers = None, None, None
    for stream in xdf_data:
        stream_name = stream['info']['name'][0]
        if stream_name == STREAM_NAMES['filtered_data']:
            filtered_stream = stream
        elif stream_name == STREAM_NAMES['raw_data']:
            raw_stream = stream
        elif stream_name == STREAM_NAMES['stimuli']:
            markers = stream

    if filtered_stream is None or raw_stream is None:
        raise ValueError("One of the necessary streams could not be found. Make sure that both the filtered and raw signal streams are present in the XDF recording")

    # Keep only the samples that exist in both signal streams
    n_samples = min(int(filtered_stream['footer']['info']['sample_count'][0]),
                    int(raw_stream['footer']['info']['sample_count'][0]))

    columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
    if markers is not None:
        columns.append("online_stimulations_portiloop")

    # Fill the table column by column instead of building it row by row;
    # the stimulation column (when present) starts out as all zeros.
    time_stamps = filtered_stream['time_stamps'][:n_samples]
    data = np.zeros((n_samples, len(columns)), dtype=np.float64)
    data[:, 0] = time_stamps
    data[:, 1] = filtered_stream['time_series'][:n_samples, channel - 1]
    data[:, 2] = raw_stream['time_series'][:n_samples, channel - 1]

    # Add markers
    if markers is not None and n_samples > 0:
        for marker_ts in markers['time_stamps']:
            # Search only the retained time stamps so the resulting index is
            # always a valid row, even when the filtered stream is longer
            # than the raw one.
            nearest_row = np.abs(time_stamps - marker_ts).argmin()
            data[nearest_row, 3] = 1

    return data, columns
    

def offline_detect(method, data, timesteps, freq):
    """Run an offline spindle detector and return a per-sample 0/1 mask.

    Args:
        method: ``"Lacourse"`` or ``"Wamsley"``, selecting the wonambi
            detection routine to run.
        data: Signal array handed to the detector; the mask has its shape.
        timesteps: Array of time values used to convert each detected
            spindle's ``start``/``end`` into sample indices (nearest value).
        freq: Sampling frequency passed to the detector and used to build
            its time axis.

    Returns:
        An array of zeros shaped like ``data`` with ones from the index
        nearest each spindle's start up to, but not including, the index
        nearest its end.

    Raises:
        ValueError: If ``method`` is neither ``"Lacourse"`` nor ``"Wamsley"``.
    """
    # Detector time axis: sample index divided by the sampling frequency.
    sample_times = np.arange(0, len(data)) / freq

    if method == "Lacourse":
        events, _, _ = detect_Lacourse2018(
            data, freq, sample_times, DetectSpindle(method='Lacourse2018'))
    elif method == "Wamsley":
        events, _, _ = detect_Wamsley2012(
            data, freq, sample_times, DetectSpindle(method='Wamsley2012'))
    else:
        raise ValueError("Invalid method")

    # Turn the list of detected events into a sample-wise mask.
    mask = np.zeros(data.shape)
    for event in events:
        first = np.argmin(np.abs(timesteps - event["start"]))
        last = np.argmin(np.abs(timesteps - event["end"]))
        mask[first:last] = 1
    return mask


def offline_filter(signal, freq, *, notch_freq=60.0, quality=100.0,
                   lowcut=0.5, highcut=40.0, order=4):
    """Notch-filter, band-pass and detrend a signal using zero-phase filtering.

    With the default keyword arguments the result is identical to the
    previous fixed configuration (60 Hz notch with Q=100, 4th-order
    Butterworth band-pass from 0.5 Hz to 40 Hz).

    Args:
        signal: Array of samples to filter.
        freq: Sampling frequency of ``signal``, in the same unit as the
            frequency arguments below.
        notch_freq: Frequency removed by the notch filter (e.g. 50.0 for
            50 Hz mains).
        quality: Quality factor of the notch filter.
        lowcut: Lower edge of the band-pass filter.
        highcut: Upper edge of the band-pass filter.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered and detrended signal.

    Raises:
        ValueError: Propagated from scipy, e.g. when a requested frequency is
            not below the Nyquist frequency ``freq / 2``.
    """
    # Notch filter
    b, a = iirnotch(notch_freq, quality, freq)
    signal = filtfilt(b, a, signal)

    # Bandpass filter; butter() is given edges normalised by the Nyquist frequency
    nyquist = freq / 2.0
    b, a = butter(order, [lowcut / nyquist, highcut / nyquist], btype='bandpass')
    signal = filtfilt(b, a, signal)

    # Detrend the signal
    return detrend(signal)