Spaces:
Sleeping
Sleeping
File size: 6,764 Bytes
2cb7306 986653d 2cb7306 1737659 2cb7306 1737659 111f264 2cb7306 111f264 2cb7306 bf3350c 2cb7306 111f264 2cb7306 bf3350c 2cb7306 111f264 2cb7306 bf3350c 2cb7306 111f264 2cb7306 bf3350c 2cb7306 1737659 986653d 2cb7306 111f264 2cb7306 986653d 2cb7306 986653d 2cb7306 986653d 2cb7306 d4bce82 804607c 95e2338 804607c 95e2338 2cb7306 95e2338 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import numpy as np
from portiloop.src.detection import SleepSpindleRealTimeDetector
from portiloop.src.stimulation import UpStateDelayer
from portiloop.src.processing import FilterPipeline
from portiloop.src.demo.utils import OfflineIsolatedSpindleRealTimeStimulator, OfflineSpindleTrainRealTimeStimulator, compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
import gradio as gr
def _append_column(data_whole, columns, values, name):
    """Return ``data_whole`` with ``values`` concatenated as a new last column.

    ``values`` is turned into an array and given a trailing axis so it can be
    concatenated along axis 1. ``name`` is appended to ``columns`` (in place)
    only after the concatenation succeeds, keeping the header aligned with the
    array. ``data_whole`` itself is not modified; the new array is returned.
    """
    column = np.expand_dims(np.asarray(values), axis=1)
    data_whole = np.concatenate((data_whole, column), axis=1)
    columns.append(name)
    return data_whole


def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, detect_trains,
                stimulation_phase="Fast", buffer_time=0.25, output_path="output.csv"):
    """Run the selected offline/online filtering and spindle detection on an .xdf recording.

    Parameters
    ----------
    xdf_file:
        The uploaded recording: either an object exposing ``.name`` (the path)
        or a path string.
    detect_filter_opts:
        Collection of selected option indices:
        0 = offline filtering, 1 = Lacourse detection, 2 = Wamsley detection,
        3 = online filtering, 4 = online detection.
    threshold:
        Threshold forwarded to ``SleepSpindleRealTimeDetector``.
    channel_num:
        Channel number forwarded (as ``int``) to ``xdf2array``.
    freq:
        Sampling rate (converted with ``int``) used by the filters, detectors
        and the up-state delayer.
    detect_trains:
        Stimulation strategy: "All Spindles", "Trains" or "Isolated & First".
    stimulation_phase:
        "Fast" means no delayer. Any other value attaches an ``UpStateDelayer``;
        the boolean ``stimulation_phase == 'Peak'`` is forwarded to it.
    buffer_time:
        Forwarded as ``time_to_buffer`` to ``UpStateDelayer``.
    output_path:
        Path of the CSV file written with every computed column.

    Returns
    -------
    tuple
        ``(output_path, output_table)``. ``output_table`` is the result of
        ``compute_output_table``, or ``None`` when online detection was not run
        or the recording has no ``online_stimulations_portiloop`` column.

    Raises
    ------
    gr.Error
        On inconsistent options, a missing file, or an unknown ``detect_trains``
        value when online detection is requested.
    """
    # Decode the checkbox-group selection (indices of the ticked options).
    offline_filtering = 0 in detect_filter_opts
    lacourse = 1 in detect_filter_opts
    wamsley = 2 in detect_filter_opts
    online_filtering = 3 in detect_filter_opts
    online_detection = 4 in detect_filter_opts

    # Stimulation strategies selectable from the UI.
    stimulator_classes = {
        "All Spindles": OfflineSleepSpindleRealTimeStimulator,
        "Trains": OfflineSpindleTrainRealTimeStimulator,
        "Isolated & First": OfflineIsolatedSpindleRealTimeStimulator,
    }

    # Make sure the inputs make sense before doing any expensive work.
    if not offline_filtering and (lacourse or wamsley):
        raise gr.Error("You can't use the offline detection methods without offline filtering.")
    if not online_filtering and online_detection:
        raise gr.Error("You can't use the online detection without online filtering.")
    if xdf_file is None:
        raise gr.Error("Please upload a .xdf file.")
    if online_detection and detect_trains not in stimulator_classes:
        # Previously this fell through and failed later with an unbound `stimulator`.
        raise gr.Error(f"Unknown stimulation mode: {detect_trains!r}.")

    freq = int(freq)

    # Read the xdf file to a numpy array. Depending on the Gradio version /
    # component configuration the upload may be a path string or an object
    # with a `.name` attribute; accept both.
    print("Loading xdf file...")
    xdf_path = getattr(xdf_file, "name", xdf_file)
    data_whole, columns = xdf2array(xdf_path, int(channel_num))

    # Offline filtering of the raw signal.
    if offline_filtering:
        print("Filtering offline...")
        offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
        data_whole = _append_column(data_whole, columns, offline_filtered_data, "offline_filtered_signal")

    # Offline detectors (both require offline filtering, validated above).
    if wamsley or lacourse:
        offline_signal = data_whole[:, columns.index("offline_filtered_signal")]
        time_stamps = data_whole[:, columns.index("time_stamps")]

        # Sleep-staging approximation; the mask is shared by both detectors.
        print("Sleep staging...")
        mask = sleep_stage(offline_signal, threshold=150, group_size=100)

        if wamsley:
            print("Running Wamsley detection...")
            wamsley_data = offline_detect("Wamsley", offline_signal, time_stamps, freq, mask)
            data_whole = _append_column(data_whole, columns, wamsley_data, "wamsley_spindles")

        if lacourse:
            print("Running Lacourse detection...")
            lacourse_data = offline_detect("Lacourse", offline_signal, time_stamps, freq, mask)
            data_whole = _append_column(data_whole, columns, lacourse_data, "lacourse_spindles")

    raw_signal = data_whole[:, columns.index("raw_signal")]

    # Online detection implies online filtering (validated above), so this one
    # condition covers both online stages.
    if online_filtering:
        online_filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
        detector = None
        stimulator = None
        stimulation_delayer = None
        if online_detection:
            # Always channel 1: only a single channel is fed to the detector.
            detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)
            stimulator = stimulator_classes[detect_trains]()
            if stimulation_phase != "Fast":
                stimulation_delayer = UpStateDelayer(
                    freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
                stimulator.add_delayer(stimulation_delayer)

        print("Running online filtering and detection...")
        points = []
        online_activations = []
        delayed_stims = []
        # Feed the pipeline one sample at a time, emulating the real-time loop.
        for point in raw_signal:
            filtered_point = online_filter.filter(np.array([point])).tolist()
            points.append(filtered_point[0])
            if not online_detection:
                continue
            result = detector.detect([filtered_point])
            if stimulation_delayer is not None:
                delayed_stims.append(1 if stimulation_delayer.step_timesteps(filtered_point[0]) else 0)
            online_activations.append(1 if stimulator.stimulate(result) else 0)

        data_whole = _append_column(data_whole, columns, points, "online_filtered_signal")
        if online_detection:
            data_whole = _append_column(data_whole, columns, online_activations, "online_stimulations")
            if stimulation_delayer is not None:
                data_whole = _append_column(data_whole, columns, delayed_stims, "delayed_stimulations")

    print("Saving output...")
    # Output all columns to a CSV file with a plain (uncommented) header row.
    np.savetxt(output_path, data_whole, delimiter=",", header=",".join(columns), comments="")

    # Compare the online stimulations computed here against the
    # `online_stimulations_portiloop` column read from the recording (and
    # against the offline detections, when enabled). Without online detection
    # there is nothing to compare, which previously raised ValueError.
    output_table = None
    if online_detection and "online_stimulations_portiloop" in columns:
        output_table = compute_output_table(
            data_whole[:, columns.index("online_stimulations")],
            data_whole[:, columns.index("online_stimulations_portiloop")],
            data_whole[:, columns.index("lacourse_spindles")] if lacourse else None,
            data_whole[:, columns.index("wamsley_spindles")] if wamsley else None,
        )
    elif online_detection:
        print("No 'online_stimulations_portiloop' column in recording; skipping comparison table.")

    print("Done!")
    return output_path, output_table