Spaces:
Sleeping
Sleeping
import numpy as np | |
import pyxdf | |
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012 | |
from scipy.signal import butter, filtfilt, iirnotch, detrend | |
import time | |
from portiloop.src.stimulation import Stimulator | |
# Stream names that xdf2array matches against each stream's
# ``info.name`` entry, keyed by the role the stream plays in this module.
STREAM_NAMES = dict(
    filtered_data='Portiloop Filtered',
    raw_data='Portiloop Raw Data',
    stimuli='Portiloop_stimuli',
)
def sleep_stage(data, threshold=150, group_size=2):
    """Sleep stage approximation using a threshold and a group size.

    Samples whose value is above ``threshold`` or below ``-threshold`` are
    masked out. When two such samples are at most ``group_size`` indices
    apart, the samples lying between them are masked out as well.

    Args:
        data: 1-D sequence of signal values (converted with ``np.asarray``).
        threshold: Amplitude beyond which a sample is masked.
        group_size: Largest index distance between two out-of-range samples
            for the gap between them to be masked too.

    Returns:
        Sorted numpy array of the indices of ``data`` which CAN be used for
        offline detection. These indices can then be used to reconstruct the
        signal from the original data. When nothing is out of range, every
        index is returned.
    """
    data = np.asarray(data)

    # Find all indexes where the signal is above or below the threshold
    above = np.where(data > threshold)
    below = np.where(data < -threshold)
    indices = np.sort(np.concatenate((above, below), axis=1)[0])

    # Positions in `indices` of the out-of-range samples that follow the
    # previous out-of-range sample by at most `group_size` indices
    groups = np.where(np.diff(indices) <= group_size)[0] + 1

    # Indices strictly between each such sample and its predecessor
    fillers = [np.arange(indices[g - 1] + 1, indices[g]) for g in groups]

    # np.concatenate rejects an empty list, which happens whenever no two
    # out-of-range samples are close together (including perfectly clean data)
    if fillers:
        masked_indices = np.concatenate([indices] + fillers)
    else:
        masked_indices = indices

    return np.setdiff1d(np.arange(len(data)), masked_indices)
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
    """Decide, call by call, whether a detection should trigger a stimulation.

    Time is tracked with ``self.index``, a counter incremented once per
    ``stimulate`` call. A truthy detection triggers a stimulation only when
    more than ``wait_timesteps`` counter steps separate it from the previous
    truthy detection.
    """

    def __init__(self):
        self.index = 0
        # Kept in the same unit as self.index (a call counter). Seeding this
        # with a wall-clock value made the first elapsed-time comparison
        # hugely negative, so the first detection could never stimulate.
        self.last_detected_ts = self.index
        self.wait_t = 0.4  # 400 ms
        # NOTE(review): 250 is presumably the sampling rate in Hz -- confirm.
        self.wait_timesteps = int(self.wait_t * 250)
        self.delayer = None

    def stimulate(self, detection_signal):
        """Advance the counter by one and process the detection flags.

        Args:
            detection_signal: Iterable of truthy/falsy detection flags. All
                flags of one call share the same counter value.

        Returns:
            True if at least one flag of this call triggered a stimulation,
            False otherwise.
        """
        self.index += 1
        stim = False
        for sig in detection_signal:
            # We detect a stimulation
            if sig:
                # Record time of stimulation
                ts = self.index

                # Check if time since last stimulation is long enough
                if ts - self.last_detected_ts > self.wait_timesteps:
                    if self.delayer is not None:
                        # If we have a delayer, notify it
                        self.delayer.detected()
                    stim = True

                # Every detection, stimulated or not, restarts the wait
                self.last_detected_ts = ts
        return stim

    def add_delayer(self, delayer):
        """Attach ``delayer`` and replace its ``stimulate`` with a stub returning True."""
        self.delayer = delayer
        self.delayer.stimulate = lambda: True
class OfflineSpindleTrainRealTimeStimulator(OfflineSleepSpindleRealTimeStimulator):
    """Stimulate only on detections that closely follow a previous detection.

    A truthy detection triggers a stimulation when the number of counter
    steps since the previous truthy detection is strictly between
    ``wait_timesteps`` and ``max_spindle_train_t * 250``.
    """

    def __init__(self):
        super().__init__()
        self.max_spindle_train_t = 6.0

    def stimulate(self, detection_signal):
        """Advance the counter by one and return True if a stimulation fired."""
        self.index += 1
        triggered = False
        for detected in detection_signal:
            if not detected:
                continue

            now = self.index
            since_last = now - self.last_detected_ts

            # Long enough after the previous detection, yet still inside
            # the maximum spindle-train window
            if since_last > self.wait_timesteps and since_last < int(self.max_spindle_train_t * 250):
                delayer = self.delayer
                if delayer is not None:
                    # A delayer is attached: let it know about the detection
                    delayer.detected()
                triggered = True

            # Every detection becomes the new reference point
            self.last_detected_ts = now
        return triggered
class OfflineIsolatedSpindleRealTimeStimulator(OfflineSpindleTrainRealTimeStimulator):
    """Stimulate only on detections that come long after the previous one.

    A truthy detection triggers a stimulation when more than
    ``max_spindle_train_t * 250`` counter steps separate it from the
    previous truthy detection.
    """

    def stimulate(self, detection_signal):
        """Advance the counter by one and return True if a stimulation fired."""
        self.index += 1
        triggered = False
        for detected in detection_signal:
            if not detected:
                continue

            now = self.index
            since_last = now - self.last_detected_ts

            # Only detections outside the spindle-train window qualify
            if since_last > int(self.max_spindle_train_t * 250):
                delayer = self.delayer
                if delayer is not None:
                    # A delayer is attached: let it know about the detection
                    delayer.detected()
                triggered = True

            # Every detection becomes the new reference point
            self.last_detected_ts = now
        return triggered
def xdf2array(xdf_path, channel):
    """Load one channel of a Portiloop XDF recording into a 2-D array.

    Args:
        xdf_path: Path handed to ``pyxdf.load_xdf``.
        channel: 1-based channel number; column ``channel - 1`` of each
            signal stream is extracted.

    Returns:
        Tuple ``(table, columns)``. ``table`` has one row per sample and the
        columns named in ``columns``: the filtered stream's time stamps, the
        filtered signal, the raw signal and, when a stimuli stream exists, a
        0/1 column set to 1 on the sample closest to each stimulus time
        stamp. The number of rows is the length of the shorter signal
        stream; with no samples an empty ``(0, len(columns))`` array is
        returned.

    Raises:
        ValueError: If the filtered or the raw signal stream is missing.
    """
    xdf_data, _ = pyxdf.load_xdf(xdf_path)

    # Load all streams given their names
    filtered_stream, raw_stream, markers = None, None, None
    for stream in xdf_data:
        stream_name = stream['info']['name'][0]
        if stream_name == STREAM_NAMES['filtered_data']:
            filtered_stream = stream
        elif stream_name == STREAM_NAMES['raw_data']:
            raw_stream = stream
        elif stream_name == STREAM_NAMES['stimuli']:
            markers = stream

    if filtered_stream is None or raw_stream is None:
        raise ValueError("One of the necessary streams could not be found. Make sure that both the filtered and the raw signal streams are present in the XDF recording")

    columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
    if markers is not None:
        columns.append("online_stimulations_portiloop")

    time_stamps = np.asarray(filtered_stream['time_stamps'])
    filtered_series = np.asarray(filtered_stream['time_series'])
    raw_series = np.asarray(raw_stream['time_series'])

    # Keep only as many samples as every source can provide; the footer
    # counts are also bounded by the actual array lengths
    n_samples = min(int(filtered_stream['footer']['info']['sample_count'][0]),
                    int(raw_stream['footer']['info']['sample_count'][0]),
                    len(time_stamps), len(filtered_series), len(raw_series))
    if n_samples <= 0:
        return np.empty((0, len(columns))), columns

    time_stamps = time_stamps[:n_samples]
    table_columns = [time_stamps,
                     filtered_series[:n_samples, channel - 1],
                     raw_series[:n_samples, channel - 1]]

    # Add markers
    if markers is not None:
        stimulations = np.zeros(n_samples)
        for time_stamp in markers['time_stamps']:
            # Search only the retained time stamps so the index always
            # falls inside the table, even when the filtered stream is
            # longer than the raw one
            stimulations[np.abs(time_stamps - time_stamp).argmin()] = 1
        table_columns.append(stimulations)

    return np.column_stack(table_columns), columns
def offline_detect(method, data, timesteps, freq, mask):
    """Run an offline wonambi spindle detector and return a per-sample marker array.

    Args:
        method: ``"Lacourse"`` or ``"Wamsley"``.
        data: Signal array; indexed with ``mask`` before detection.
        timesteps: Array against which each spindle's ``"start"`` and
            ``"end"`` values are matched (closest entry wins).
        freq: Divisor turning sample indices into the time axis given to
            the detector, and the frequency passed to the detector.
        mask: Index applied to both ``data`` and the time axis.

    Returns:
        Array of zeros shaped like ``data`` with 1 written on
        ``[start_index, end_index)`` for every detected spindle.

    Raises:
        ValueError: If ``method`` is neither supported name.
    """
    # Keep only the samples selected by the mask, and the matching time axis
    kept_signal = data[mask]
    seconds = np.arange(0, len(data)) / freq
    kept_seconds = seconds[mask]

    if method not in ("Lacourse", "Wamsley"):
        raise ValueError("Invalid method")

    # Run the requested offline detector
    if method == "Lacourse":
        detector = DetectSpindle(method='Lacourse2018')
        spindles, _, _ = detect_Lacourse2018(kept_signal, freq, kept_seconds, detector)
    else:
        detector = DetectSpindle(method='Wamsley2012')
        spindles, _, _ = detect_Wamsley2012(kept_signal, freq, kept_seconds, detector)

    # Turn the list of events into a per-sample 0/1 array
    markers = np.zeros(data.shape)
    for event in spindles:
        onset, offset = event["start"], event["end"]
        # Closest entry of `timesteps` to each boundary
        first = np.argmin(np.abs(timesteps - onset))
        last = np.argmin(np.abs(timesteps - offset))
        markers[first:last] = 1
    return markers
def offline_filter(signal, freq, *, notch_freq=60.0, notch_q=100.0,
                   lowcut=0.5, highcut=40.0, order=4):
    """Notch-filter, band-pass-filter and detrend a signal.

    Both filters are applied with ``filtfilt``. The defaults reproduce the
    previously hard-coded settings.

    Args:
        signal: Array of samples to filter.
        freq: Sampling frequency of ``signal`` (Hz).
        notch_freq: Frequency to be removed from the signal (Hz).
        notch_q: Quality factor of the notch filter.
        lowcut: Lower edge of the band-pass (Hz).
        highcut: Upper edge of the band-pass (Hz).
        order: Order passed to ``butter`` for the band-pass design.

    Returns:
        The filtered and detrended signal.
    """
    # Notch filter
    b, a = iirnotch(notch_freq, notch_q, freq)
    signal = filtfilt(b, a, signal)

    # Bandpass filter, with edges normalised by the Nyquist frequency
    nyquist = freq / 2.0
    b, a = butter(order, [lowcut / nyquist, highcut / nyquist], btype='bandpass')
    signal = filtfilt(b, a, signal)

    # Detrend the signal
    return detrend(signal)
def _count_onsets(labels):
    """Count samples equal to 1 that are first in the array or preceded by a 0."""
    labels = np.asarray(labels)
    preceded_by_zero = np.ones(labels.shape, dtype=bool)
    preceded_by_zero[1:] = labels[:-1] == 0
    return int(np.count_nonzero((labels == 1) & preceded_by_zero))


def _count_hits_near(hit_indices, reference, half_window):
    """Count indices i for which reference[i - half_window:i + half_window] contains a 1."""
    is_one = np.asarray(reference) == 1
    # running[k] is the number of ones in reference[:k]
    running = np.concatenate(([0], np.cumsum(is_one)))
    # Clamp the window to the array so a negative start cannot wrap around
    low = np.clip(hit_indices - half_window, 0, len(is_one))
    high = np.clip(hit_indices + half_window, 0, len(is_one))
    return int(np.count_nonzero(running[high] > running[low]))


def compute_output_table(irl_online_stimulations, online_stimulation, lacourse_spindles, wamsley_spindles, time_overlap_s=2.0):
    """Build a markdown table comparing the detection methods.

    Args:
        irl_online_stimulations: Per-sample 0/1 array of the stimulations
            recorded during the original run.
        online_stimulation: Per-sample 0/1 array of the stimulations
            produced in this run.
        lacourse_spindles: Per-sample 0/1 array from the Lacourse detector,
            or None to leave that row out.
        wamsley_spindles: Per-sample 0/1 array from the Wamsley detector,
            or None to leave that row out.
        time_overlap_s: Total width of the window centred on each
            stimulation of this run in which a recorded stimulation counts
            as overlapping. Converted to samples with a factor of 250.

    Returns:
        The markdown table as a string, one row per method.
    """
    half_window = int((time_overlap_s / 2) * 250)
    online_hits = np.flatnonzero(np.asarray(online_stimulation) == 1)

    # Count the number of spindles in this run which overlap with spindles found IRL
    irl_spindles_count = np.sum(irl_online_stimulations)
    both_online_irl = _count_hits_near(online_hits, irl_online_stimulations, half_window)

    # Count the number of spindles detected by each method
    online_stimulation_count = np.sum(online_stimulation)

    # Create markdown table with the results
    rows = [
        "| Method | # of Detected spindles | Overlap with Online (in tool) |\n",
        "| --- | --- | --- |\n",
        f"| Online in Tool | {online_stimulation_count} | {online_stimulation_count} |\n",
        f"| Online detection IRL | {irl_spindles_count} | {both_online_irl} |\n",
    ]

    for label, spindles in (("Lacourse", lacourse_spindles), ("Wamsley", wamsley_spindles)):
        if spindles is None:
            continue
        spindles = np.asarray(spindles)
        # A spindle is counted once, at the sample where it starts
        spindles_count = _count_onsets(spindles)
        # Stimulations of this run falling on a sample marked as spindle
        both_online = int(np.count_nonzero(spindles[online_hits] == 1))
        rows.append(f"| {label} | {spindles_count} | {both_online} |\n")

    return "".join(rows)