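# NOTE: the six lines below are residue from the web file viewer this module
# was copied from; they are kept as comments so the file remains valid Python.
# MiloSobral's picture
# Finished setting up the demo and fixed my git stupidity
# 2cb7306
# raw
# history blame
# 4.82 kB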
import numpy as np
import pyxdf
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012
from scipy.signal import butter, filtfilt, iirnotch, detrend
import time
from portiloop.src.stimulation import Stimulator
# Names under which the streams of interest are looked up in an XDF recording
# (compared against each stream's ``info.name`` field by ``xdf2array``).
STREAM_NAMES = {
    "filtered_data": "Portiloop Filtered",
    "raw_data": "Portiloop Raw Data",
    "stimuli": "Portiloop_stimuli",
}
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
    """Stimulator that reports, rather than sends, stimulation decisions.

    ``stimulate`` returns a boolean telling the caller whether a stimulation
    would be triggered for the given detection signal, subject to a minimum
    wall-clock delay (``wait_t``) since the previously recorded detection.
    """

    def __init__(self):
        # Wall-clock time of the last recorded detection; starts at creation
        # time, so the wait window also applies right after construction.
        self.last_detected_ts = time.time()
        self.wait_t = 0.4  # 400 ms
        # Optional delayer object, attached through ``add_delayer``.
        self.delayer = None

    def stimulate(self, detection_signal):
        """Return True if any entry of ``detection_signal`` triggers a stimulation.

        :param detection_signal: iterable of truthy/falsy detection flags.
        :return: True when at least one truthy flag arrived more than
            ``wait_t`` seconds after the last recorded detection.
        """
        stim = False
        for sig in detection_signal:
            # We detect a stimulation
            if sig:
                # Record time of stimulation
                ts = time.time()
                # Check if time since last stimulation is long enough
                if ts - self.last_detected_ts > self.wait_t:
                    if self.delayer is not None:
                        # If we have a delayer, notify it
                        self.delayer.detected()
                    stim = True
                # NOTE(review): nesting reconstructed from a copy without
                # indentation. As written, the timestamp is refreshed on every
                # positive sample, including those inside the wait window --
                # confirm this matches the intended behavior.
                self.last_detected_ts = ts
        return stim

    def add_delayer(self, delayer):
        """Attach ``delayer`` and replace its ``stimulate`` attribute.

        After this call ``delayer.stimulate()`` is a no-argument callable that
        always returns True.
        """
        self.delayer = delayer
        self.delayer.stimulate = lambda: True
def xdf2array(xdf_path, channel):
    """Load a Portiloop XDF recording into a 2-D array of aligned samples.

    The streams named in ``STREAM_NAMES`` are looked up in the recording. The
    filtered and raw signal streams are truncated to their common length and
    stacked column-wise; if a stimuli stream is present, an extra 0/1 column
    marks, for every marker, the sample whose timestamp is closest to it.

    :param xdf_path: path of the XDF file, passed to ``pyxdf.load_xdf``.
    :param channel: 1-based index of the channel to extract from both signal
        streams (column ``channel - 1`` of their ``time_series``).
    :return: tuple ``(data, columns)`` where ``data`` has one row per sample
        and ``columns`` lists the column names in order:
        ``time_stamps``, ``online_filtered_signal_portiloop``, ``raw_signal``
        and, only when markers exist, ``online_stimulations_portiloop``.
    :raises ValueError: if the filtered or the raw stream is missing.
    """
    xdf_data, _ = pyxdf.load_xdf(xdf_path)

    # Map each recorded stream name back to its STREAM_NAMES key. If several
    # streams share a name, the last one in the file wins.
    key_by_name = {name: key for key, name in STREAM_NAMES.items()}
    found = {}
    for stream in xdf_data:
        key = key_by_name.get(stream['info']['name'][0])
        if key is not None:
            found[key] = stream

    filtered_stream = found.get('filtered_data')
    raw_stream = found.get('raw_data')
    markers = found.get('stimuli')

    if filtered_stream is None or raw_stream is None:
        raise ValueError(
            "One of the necessary streams could not be found. Make sure that "
            "both the filtered and raw signal streams are present in XDF recording"
        )

    filtered_ts = np.asarray(filtered_stream['time_stamps'])
    filtered_series = np.asarray(filtered_stream['time_series'])
    raw_series = np.asarray(raw_stream['time_series'])

    # Common length of both signals. The footer counts are clamped to what was
    # actually loaded so slicing can never run past the end of an array.
    n_samples = min(
        int(filtered_stream['footer']['info']['sample_count'][0]),
        int(raw_stream['footer']['info']['sample_count'][0]),
        len(filtered_ts),
        len(filtered_series),
        len(raw_series),
    )

    time_stamps = filtered_ts[:n_samples]
    columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
    column_data = [
        time_stamps,
        filtered_series[:n_samples, channel - 1].astype(float),
        raw_series[:n_samples, channel - 1],
    ]

    if markers is not None:
        columns.append("online_stimulations_portiloop")
        stimulations = np.zeros(n_samples)
        if n_samples:
            for marker_ts in markers['time_stamps']:
                # Search only the retained samples: looking through the full
                # filtered stream could yield an index beyond the truncated
                # data when the filtered stream is longer than the raw one.
                stimulations[np.abs(time_stamps - marker_ts).argmin()] = 1
        column_data.append(stimulations)

    return np.column_stack(column_data), columns
def offline_detect(method, data, timesteps, freq):
    """Run an offline spindle detector and return a per-sample 0/1 mask.

    :param method: ``"Lacourse"`` or ``"Wamsley"``; selects the wonambi
        detector (``Lacourse2018`` or ``Wamsley2012``).
    :param data: signal to analyse; the result has the same shape.
    :param timesteps: array of sample times used to convert each spindle's
        ``start``/``end`` into sample indices.
    :param freq: sampling frequency, used to build the time axis given to the
        detector.
    :return: array of zeros with ones over every detected spindle.
    :raises ValueError: if ``method`` is not one of the supported names.
    """
    # Time axis handed to the detector: sample index divided by the rate.
    # NOTE(review): spindle times therefore start at 0; ``timesteps`` is
    # presumably expected on that same time base -- confirm against callers.
    sample_times = np.arange(0, len(data)) / freq

    if method == "Lacourse":
        spindles, _, _ = detect_Lacourse2018(
            data, freq, sample_times, DetectSpindle(method='Lacourse2018'))
    elif method == "Wamsley":
        spindles, _, _ = detect_Wamsley2012(
            data, freq, sample_times, DetectSpindle(method='Wamsley2012'))
    else:
        raise ValueError("Invalid method")

    # Paint each detected spindle onto a zero-initialised mask.
    mask = np.zeros(data.shape)
    for event in spindles:
        # Indices of the timesteps closest to the spindle boundaries.
        first = np.argmin(np.abs(timesteps - event["start"]))
        last = np.argmin(np.abs(timesteps - event["end"]))
        # NOTE(review): the slice excludes ``last`` itself, so the sample
        # nearest to the spindle end is not marked -- confirm this is intended.
        mask[first:last] = 1
    return mask
def offline_filter(signal, freq, notch_freq=60.0, notch_q=100.0,
                   lowcut=0.5, highcut=40.0, order=4):
    """Clean a signal with a notch filter, a band-pass filter and detrending.

    Both filters are applied with ``filtfilt`` (forward and backward), after
    which ``scipy.signal.detrend`` is applied to the result. With the default
    arguments the processing is: 60 Hz notch (Q = 100), 0.5-40 Hz Butterworth
    band-pass of order 4, detrend.

    :param signal: 1-D signal to filter.
    :param freq: sampling frequency of ``signal`` in Hz.
    :param notch_freq: frequency removed by the notch filter in Hz
        (e.g. 50.0 for 50 Hz power lines).
    :param notch_q: quality factor of the notch filter.
    :param lowcut: lower edge of the band-pass in Hz.
    :param highcut: upper edge of the band-pass in Hz.
    :param order: order passed to ``butter`` for the band-pass design.
    :return: the filtered signal, same length as the input.
    """
    # Notch filter
    b, a = iirnotch(notch_freq, notch_q, freq)
    signal = filtfilt(b, a, signal)

    # Bandpass filter; band edges are expressed relative to Nyquist.
    nyquist = freq / 2.0
    b, a = butter(order, [lowcut / nyquist, highcut / nyquist], btype='bandpass')
    signal = filtfilt(b, a, signal)

    # Detrend the signal
    return detrend(signal)