Spaces:
Sleeping
Sleeping
import time

import numpy as np
import pyxdf
from scipy.signal import butter, detrend, filtfilt, iirnotch, sosfiltfilt
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012

from portiloop.src.stimulation import Stimulator
# Stream names looked up in an XDF recording, keyed by the role each stream plays.
# Each value is compared against stream['info']['name'][0] of the loaded streams.
STREAM_NAMES = dict(
    filtered_data='Portiloop Filtered',
    raw_data='Portiloop Raw Data',
    stimuli='Portiloop_stimuli',
)
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
    """Stimulator that turns a detection signal into stimulation decisions.

    A detection is accepted only if more than ``wait_t`` seconds have elapsed
    since the previously accepted detection (or since construction).

    NOTE(review): the refractory period is measured on the host clock while
    ``stimulate`` runs, not on the timestamps of the signal being processed.
    If data is replayed faster than real time, the spacing in signal time will
    presumably differ from ``wait_t`` — confirm this is intended.
    """

    def __init__(self, wait_t=0.4):
        """Create the stimulator.

        Args:
            wait_t: Minimum interval in seconds between two accepted
                detections (default 0.4, i.e. 400 ms).
        """
        # Monotonic clock: elapsed-time checks must not be affected by system
        # clock adjustments (NTP, manual changes), which time.time() follows.
        self.last_detected_ts = time.monotonic()
        self.wait_t = wait_t
        self.delayer = None

    def stimulate(self, detection_signal):
        """Process a batch of detections and report whether to stimulate.

        Args:
            detection_signal: Iterable of truthy/falsy detection flags.

        Returns:
            True if at least one truthy flag was accepted, i.e. it arrived more
            than ``wait_t`` seconds after the previously accepted detection;
            False otherwise. Each accepted detection also notifies the delayer
            (if any) through its ``detected()`` method.
        """
        stim = False
        for sig in detection_signal:
            if not sig:
                continue
            ts = time.monotonic()
            # Enforce the refractory period between accepted detections.
            if ts - self.last_detected_ts > self.wait_t:
                if self.delayer is not None:
                    self.delayer.detected()
                stim = True
                self.last_detected_ts = ts
        return stim

    def add_delayer(self, delayer):
        """Attach a delayer that is notified of every accepted detection.

        The delayer's ``stimulate`` attribute is replaced by a callable that
        always returns True.
        """
        self.delayer = delayer
        self.delayer.stimulate = lambda: True
def xdf2array(xdf_path, channel):
    """Load one channel of a Portiloop XDF recording into a 2-D array.

    Args:
        xdf_path: Path of the recording, passed to ``pyxdf.load_xdf``.
        channel: 1-based index of the channel to extract from the signal streams.

    Returns:
        ``(array, columns)``. ``array`` has one row per sample, truncated to the
        shorter of the filtered and raw streams. ``columns`` names its columns:
        ``"time_stamps"`` (taken from the filtered stream),
        ``"online_filtered_signal_portiloop"`` and ``"raw_signal"``. Only when a
        stimuli stream is present, it also has ``"online_stimulations_portiloop"``,
        which is 1 on the row whose timestamp is closest to each stimulus marker
        and 0 elsewhere.

    Raises:
        ValueError: if ``channel`` is smaller than 1, or if the filtered or the
            raw signal stream is missing from the recording.
    """
    # A 0 or negative channel would silently wrap around to the last columns.
    if channel < 1:
        raise ValueError(f"channel is 1-based, got {channel}")

    xdf_data, _ = pyxdf.load_xdf(xdf_path)

    # Select streams by name; if a name appears twice, the last stream wins.
    role_by_name = {name: role for role, name in STREAM_NAMES.items()}
    streams = {}
    for stream in xdf_data:
        role = role_by_name.get(stream['info']['name'][0])
        if role is not None:
            streams[role] = stream
    filtered_stream = streams.get('filtered_data')
    raw_stream = streams.get('raw_data')
    markers = streams.get('stimuli')
    if filtered_stream is None or raw_stream is None:
        raise ValueError(
            "One of the necessary streams could not be found. Make sure that both the "
            "filtered and raw signal streams are present in XDF recording"
        )

    # Count the samples actually loaded instead of trusting the footer's
    # sample_count: the footer can be absent (e.g. an interrupted recording)
    # or disagree with the loaded data, leading to out-of-range indexing.
    n_samples = min(len(filtered_stream['time_stamps']), len(raw_stream['time_stamps']))
    time_stamps = np.asarray(filtered_stream['time_stamps'])[:n_samples]
    col = channel - 1  # 0-based column index into time_series

    columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
    data = [
        time_stamps,
        np.asarray(filtered_stream['time_series'][:n_samples, col], dtype=float),
        np.asarray(raw_stream['time_series'][:n_samples, col]),
    ]

    if markers is not None:
        columns.append("online_stimulations_portiloop")
        stimulations = np.zeros(n_samples)
        if n_samples:
            for time_stamp in markers['time_stamps']:
                # Search only the kept rows so the index is always valid, even
                # when the filtered stream is longer than the raw one.
                stimulations[np.abs(time_stamps - time_stamp).argmin()] = 1
        data.append(stimulations)

    return np.column_stack(data), columns
def offline_detect(method, data, timesteps, freq):
    """Run an offline wonambi spindle detector and return a per-sample mask.

    Args:
        method: ``"Lacourse"`` (Lacourse2018) or ``"Wamsley"`` (Wamsley2012).
        data: Signal passed to the detector; presumably 1-D and sampled at
            ``freq`` Hz — TODO confirm.
        timesteps: Timestamps aligned sample-by-sample with ``data``. They may
            be absolute (e.g. recording clock) or already start at 0.
        freq: Sampling frequency in Hz.

    Returns:
        Array of shape ``data.shape``. Entries from the sample nearest each
        spindle's start up to (excluding) the sample nearest its end are 1.0;
        all other entries are 0.0.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    # The detector works on time in seconds from the first sample (i / freq).
    # Named so it does not shadow the `time` module.
    detector_time = np.arange(0, len(data)) / freq
    if method == "Lacourse":
        detector = DetectSpindle(method='Lacourse2018')
        spindles, _, _ = detect_Lacourse2018(data, freq, detector_time, detector)
    elif method == "Wamsley":
        detector = DetectSpindle(method='Wamsley2012')
        spindles, _, _ = detect_Wamsley2012(data, freq, detector_time, detector)
    else:
        raise ValueError("Invalid method")

    # Spindle start/end are expressed on detector_time, which starts at 0.
    # Measure timesteps from their first sample too; otherwise absolute
    # timestamps would map every spindle onto the first samples.
    rel_timesteps = np.asarray(timesteps, dtype=float)
    if rel_timesteps.size:
        rel_timesteps = rel_timesteps - rel_timesteps[0]

    spindle_result = np.zeros(data.shape)
    for spindle in spindles:
        # Index of the timestep closest to the spindle's start and end.
        start_index = np.argmin(np.abs(rel_timesteps - spindle["start"]))
        end_index = np.argmin(np.abs(rel_timesteps - spindle["end"]))
        spindle_result[start_index:end_index] = 1
    return spindle_result
def offline_filter(signal, freq, *, notch_freq=60.0, notch_q=100.0,
                   lowcut=0.5, highcut=40.0, order=4):
    """Zero-phase filter a signal: notch, Butterworth band-pass, then detrend.

    Args:
        signal: 1-D array-like signal.
        freq: Sampling frequency in Hz.
        notch_freq: Frequency in Hz removed by the notch filter (default 60.0,
            presumably power-line interference). Pass ``None`` to skip the notch.
        notch_q: Quality factor of the notch filter (default 100.0).
        lowcut: Lower band-pass edge in Hz (default 0.5).
        highcut: Upper band-pass edge in Hz (default 40.0); must be below freq / 2.
        order: Butterworth design order of the band-pass (default 4).

    Returns:
        The filtered signal with its linear trend removed, as a numpy array.
    """
    if notch_freq is not None:
        b, a = iirnotch(notch_freq, notch_q, freq)
        signal = filtfilt(b, a, signal)

    # Second-order sections avoid the numerical problems of the (b, a) form for
    # higher-order band-pass filters whose low cutoff is tiny relative to fs.
    sos = butter(order, [lowcut, highcut], btype='bandpass', output='sos', fs=freq)
    signal = sosfiltfilt(sos, signal)

    # Remove any remaining linear trend.
    return detrend(signal)