Milo Sobral
Finished last few changes
111f264
raw
history blame
4.82 kB
import matplotlib.pyplot as plt
import numpy as np
from portiloop.src.detection import SleepSpindleRealTimeDetector
plt.switch_backend('agg')
from portiloop.src.processing import FilterPipeline
from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
import gradio as gr
def _append_column(data, columns, values, name):
    """Return *data* with *values* added as a new last column called *name*.

    ``columns`` is extended in place (after the concatenation succeeded) so
    that it stays aligned with the columns of the returned array.
    """
    # Turn the 1-D values into an (n, 1) column so they can be stacked on axis 1.
    column = np.expand_dims(np.asarray(values), axis=1)
    data = np.concatenate((data, column), axis=1)
    columns.append(name)
    return data


def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
    """Run the selected filtering/detection steps on an .xdf file and export a CSV.

    Args:
        xdf_file: Uploaded file object; its ``name`` attribute is passed to
            ``xdf2array``. ``None`` is rejected with a ``gr.Error``.
        detect_filter_opts: Collection of selected option indices from the
            checkbox group: 0 = offline filtering, 1 = Lacourse detection,
            2 = Wamsley detection, 3 = online filtering, 4 = online detection.
        threshold: Threshold forwarded to ``SleepSpindleRealTimeDetector``.
        channel_num: Channel number, converted to ``int`` and forwarded to
            ``xdf2array``.
        freq: Sampling frequency, converted to ``int`` and used for the
            offline filter, the offline detectors and the online pipeline.

    Returns:
        The string ``"output.csv"``: the path of the CSV written to the
        current working directory (overwritten on every call).

    Raises:
        gr.Error: If an offline detection method is selected without offline
            filtering, if online detection is selected without online
            filtering, or if no file was uploaded.
    """
    # Get the options from the checkbox group
    offline_filtering = 0 in detect_filter_opts
    lacourse = 1 in detect_filter_opts
    wamsley = 2 in detect_filter_opts
    online_filtering = 3 in detect_filter_opts
    online_detection = 4 in detect_filter_opts

    # Make sure the inputs make sense:
    if not offline_filtering and (lacourse or wamsley):
        raise gr.Error("You can't use the offline detection methods without offline filtering.")

    if not online_filtering and online_detection:
        raise gr.Error("You can't use the online detection without online filtering.")

    if xdf_file is None:
        raise gr.Error("Please upload a .xdf file.")

    freq = int(freq)

    # Read the xdf file to a numpy array
    print("Loading xdf file...")
    data_whole, columns = xdf2array(xdf_file.name, int(channel_num))

    # Do the offline filtering of the data
    if offline_filtering:
        print("Filtering offline...")
        offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
        data_whole = _append_column(
            data_whole, columns, offline_filtered_data, "offline_filtered_signal")

    # Offline detectors, in the order their columns must appear in the output.
    offline_detectors = (
        ("Wamsley", wamsley, "wamsley_spindles"),
        ("Lacourse", lacourse, "lacourse_spindles"),
    )
    for method, selected, column_name in offline_detectors:
        if not selected:
            continue
        print(f"Running {method} detection...")
        detections = offline_detect(
            method,
            data_whole[:, columns.index("offline_filtered_signal")],
            data_whole[:, columns.index("time_stamps")],
            freq)
        data_whole = _append_column(data_whole, columns, detections, column_name)

    # Get the data from the raw signal column
    data = data_whole[:, columns.index("raw_signal")]

    # The validation above guarantees that online detection implies online
    # filtering, so the online pass only ever runs with the filter enabled.
    if online_filtering:
        # Create the online filtering pipeline
        filter_pipeline = FilterPipeline(nb_channels=1, sampling_rate=freq)

        # Create the detector
        if online_detection:
            # always 1 because we have only one channel
            detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)
            stimulator = OfflineSleepSpindleRealTimeStimulator()

        print("Running online filtering and detection...")

        points = []
        online_activations = []

        # Go through the data one sample at a time, as a real-time system would
        for point in data:
            # Filter the data
            filtered_point = filter_pipeline.filter(np.array([point])).tolist()
            points.append(filtered_point[0])

            if online_detection:
                # Detect the spindles
                result = detector.detect([filtered_point])
                # Stimulate if necessary; store the outcome as a 0/1 flag
                online_activations.append(1 if stimulator.stimulate(result) else 0)

        data_whole = _append_column(data_whole, columns, points, "online_filtered_signal")

        if online_detection:
            data_whole = _append_column(
                data_whole, columns, online_activations, "online_stimulations")

    print("Saving output...")
    # Output the data to a csv file; comments="" keeps the header free of "# "
    np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")

    print("Done!")
    return "output.csv"