Spaces:
Sleeping
Sleeping
File size: 4,822 Bytes
2cb7306 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import time

import numpy as np
import pyxdf
from scipy.signal import butter, detrend, filtfilt, iirnotch, sosfiltfilt
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012

from portiloop.src.stimulation import Stimulator
# Names of the streams looked up in a Portiloop XDF recording, keyed by the
# role each stream plays (filtered signal, raw signal, stimulation markers).
STREAM_NAMES = dict(
    filtered_data='Portiloop Filtered',
    raw_data='Portiloop Raw Data',
    stimuli='Portiloop_stimuli',
)
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
    """Offline stimulator that reports, rather than delivers, stimulations.

    ``stimulate`` returns True when at least one detection in the given
    signal is far enough (more than ``wait_t`` seconds) from the previous
    detection; an attached delayer, if any, is notified of each accepted
    detection.
    """

    def __init__(self, wait_t=0.4):
        """Create the stimulator.

        Args:
            wait_t: minimum time in seconds between the previous detection
                and a new one for the new one to be accepted. Defaults to
                0.4 s (400 ms).
        """
        # The clock starts at construction time, so detections within the
        # first ``wait_t`` seconds after creation are not accepted.
        self.last_detected_ts = time.time()
        self.wait_t = wait_t
        self.delayer = None

    def stimulate(self, detection_signal):
        """Return True if any truthy sample of ``detection_signal`` is accepted.

        Args:
            detection_signal: iterable of detection flags; truthy entries are
                detections.

        Returns:
            True if at least one detection passed the ``wait_t`` check during
            this call, False otherwise.
        """
        # NOTE(review): gaps are measured with time.time() (wall clock), not
        # in signal time; if recorded data is replayed faster than real time,
        # the effective gap in the signal is shorter than ``wait_t``.
        stim = False
        for sig in detection_signal:
            if not sig:
                continue
            ts = time.time()
            # Only accept the detection if enough time has elapsed since the
            # previous one.
            if ts - self.last_detected_ts > self.wait_t:
                if self.delayer is not None:
                    # Let the attached delayer know a detection happened
                    self.delayer.detected()
                stim = True
            # NOTE(review): the timer restarts on every detection, including
            # ones rejected above — confirm this is the intended refractory
            # behaviour.
            self.last_detected_ts = ts
        return stim

    def add_delayer(self, delayer):
        """Attach ``delayer`` so it is notified of accepted detections.

        The delayer's ``stimulate`` attribute is replaced by a callable that
        simply returns True (presumably so no real stimulus is produced
        offline — confirm against the delayer implementation).
        """
        self.delayer = delayer
        self.delayer.stimulate = lambda: True
def xdf2array(xdf_path, channel):
    """Load one channel of a Portiloop XDF recording into a 2-D array.

    Args:
        xdf_path: path to the ``.xdf`` file, passed to ``pyxdf.load_xdf``.
        channel: 1-based channel number; column ``channel - 1`` of each
            signal stream is extracted.

    Returns:
        A tuple ``(data, columns)``. ``data`` has one row per sample, with the
        columns named in ``columns``: the filtered stream's time stamps, the
        filtered signal, the raw signal and — only when a stimuli stream is
        present — a 0/1 flag marking the sample closest to each stimulation
        marker.

    Raises:
        ValueError: if ``channel`` is smaller than 1, or if the filtered or
            raw signal stream is missing from the recording.
    """
    if channel < 1:
        raise ValueError(f"channel is 1-based and must be >= 1, got {channel}")

    xdf_data, _ = pyxdf.load_xdf(xdf_path)

    # Pick out the streams we need by name
    filtered_stream, raw_stream, markers = None, None, None
    for stream in xdf_data:
        name = stream['info']['name'][0]
        if name == STREAM_NAMES['filtered_data']:
            filtered_stream = stream
        elif name == STREAM_NAMES['raw_data']:
            raw_stream = stream
        elif name == STREAM_NAMES['stimuli']:
            markers = stream

    if filtered_stream is None or raw_stream is None:
        raise ValueError(
            "One of the necessary streams could not be found. Make sure that both "
            f"the {STREAM_NAMES['filtered_data']!r} and {STREAM_NAMES['raw_data']!r} "
            "streams are present in the XDF recording"
        )

    # Use the number of samples actually loaded rather than the footer's
    # sample_count: the footer may be absent (e.g. an interrupted recording)
    # or disagree with the arrays, which would raise KeyError / IndexError.
    filtered_ts = np.asarray(filtered_stream['time_stamps'])
    n_samples = min(len(filtered_ts), len(raw_stream['time_stamps']))
    filtered_ts = filtered_ts[:n_samples]

    col = channel - 1
    filtered_sig = np.asarray(filtered_stream['time_series'])[:n_samples, col].astype(float)
    raw_sig = np.asarray(raw_stream['time_series'])[:n_samples, col]

    columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
    data_columns = [filtered_ts, filtered_sig, raw_sig]

    if markers is not None:
        columns.append("online_stimulations_portiloop")
        stimulations = np.zeros(n_samples)
        if n_samples > 0:
            for time_stamp in markers['time_stamps']:
                # Search only the kept samples so a marker near the end of a
                # longer filtered stream cannot index past the output.
                stimulations[np.abs(filtered_ts - time_stamp).argmin()] = 1
        data_columns.append(stimulations)

    return np.column_stack(data_columns), columns
def offline_detect(method, data, timesteps, freq):
    """Run an offline wonambi spindle detector and return a per-sample mask.

    Args:
        method: ``"Lacourse"`` (Lacourse 2018) or ``"Wamsley"`` (Wamsley 2012).
        data: signal passed to the detector (presumably 1-D — confirm).
        timesteps: one time stamp per sample of ``data``, used to map each
            detected spindle's start/end back to sample indices. May start
            at 0 or at an absolute time (e.g. the recording clock).
        freq: sampling frequency of ``data`` in Hz.

    Returns:
        Array with the same shape as ``data``: 1 on samples inside a detected
        spindle (from its start index up to, but excluding, its end index),
        0 elsewhere.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Seconds since the first sample, assuming uniform sampling at ``freq``.
    # The detector works on this axis, so the spindle "start"/"end" it reports
    # are relative to 0. Named ``times`` so the ``time`` module isn't shadowed.
    times = np.arange(len(data)) / freq
    if method == "Lacourse":
        detector = DetectSpindle(method='Lacourse2018')
        spindles, _, _ = detect_Lacourse2018(data, freq, times, detector)
    elif method == "Wamsley":
        detector = DetectSpindle(method='Wamsley2012')
        spindles, _, _ = detect_Wamsley2012(data, freq, times, detector)
    else:
        raise ValueError(f"Invalid method {method!r}: expected 'Lacourse' or 'Wamsley'")

    spindle_result = np.zeros(data.shape)
    if len(spindles) == 0:
        return spindle_result

    timesteps = np.asarray(timesteps)
    # Spindle times are relative to the first sample while ``timesteps`` may
    # hold absolute time stamps; shift by the first time stamp so the
    # nearest-sample lookup works either way (a no-op when it starts at 0).
    origin = timesteps[0]
    for spindle in spindles:
        start_index = np.argmin(np.abs(timesteps - (origin + spindle["start"])))
        end_index = np.argmin(np.abs(timesteps - (origin + spindle["end"])))
        spindle_result[start_index:end_index] = 1
    return spindle_result
def offline_filter(signal, freq, *, notch_freq=60.0, q=100.0,
                   lowcut=0.5, highcut=40.0, order=4):
    """Apply a zero-phase notch and band-pass filter, then detrend.

    Args:
        signal: samples to filter; filtering and detrending run along the
            last axis.
        freq: sampling frequency in Hz.
        notch_freq: frequency removed by the notch filter, in Hz (60 Hz by
            default; pass 50.0 for 50 Hz mains).
        q: quality factor of the notch (higher means a narrower notch).
        lowcut: lower band-pass edge in Hz.
        highcut: upper band-pass edge in Hz.
        order: order of the Butterworth band-pass design.

    Returns:
        The filtered, linearly detrended signal, same shape as ``signal``.

    Raises:
        ValueError: raised by scipy when ``notch_freq``, ``lowcut`` or
            ``highcut`` is not strictly between 0 and ``freq / 2``.
    """
    # Notch filter (forward-backward, so no phase shift)
    b, a = iirnotch(notch_freq, q, fs=freq)
    signal = filtfilt(b, a, signal)

    # Band-pass filter. Second-order sections are used because a 0.5 Hz edge
    # puts poles very close to the unit circle, where the single (b, a)
    # polynomial form loses numerical accuracy.
    sos = butter(order, [lowcut, highcut], btype='bandpass', fs=freq, output='sos')
    signal = sosfiltfilt(sos, signal)

    # Remove any remaining linear trend
    return detrend(signal)
|