File size: 4,823 Bytes
2cb7306
 
 
 
 
 
 
 
 
111f264
 
 
 
 
 
 
2cb7306
 
 
 
 
 
 
 
111f264
 
 
2cb7306
 
 
 
 
 
 
111f264
2cb7306
 
 
 
 
 
 
 
111f264
2cb7306
 
 
 
 
 
 
 
 
 
111f264
2cb7306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111f264
 
2cb7306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111f264
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import matplotlib.pyplot as plt
import numpy as np
from portiloop.src.detection import SleepSpindleRealTimeDetector
plt.switch_backend('agg')
from portiloop.src.processing import FilterPipeline
from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
import gradio as gr


def _append_column(data_whole, columns, values, name):
    """Return ``data_whole`` with ``values`` added as a new last column.

    ``values`` is expanded along axis 1 so it can be concatenated column-wise
    with ``data_whole``; ``name`` is appended to ``columns`` in place so the
    header list stays aligned with the array columns.
    """
    data_whole = np.concatenate((data_whole, np.expand_dims(values, axis=1)), axis=1)
    columns.append(name)
    return data_whole


def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
    """Run the selected filtering/detection steps on an xdf recording.

    The recording is loaded with ``xdf2array``; each enabled step appends one
    column to the data, and the result is written to ``output.csv`` with a
    header row naming the columns.

    Args:
        xdf_file: Uploaded file object; its ``.name`` attribute is used as the
            path passed to ``xdf2array``. ``None`` is rejected.
        detect_filter_opts: Collection of selected option indices from the
            checkbox group: 0 = offline filtering, 1 = Lacourse detection,
            2 = Wamsley detection, 3 = online filtering, 4 = online detection.
        threshold: Threshold forwarded to ``SleepSpindleRealTimeDetector``.
        channel_num: Channel number, converted to ``int`` and forwarded to
            ``xdf2array``.
        freq: Sampling frequency, converted to ``int`` and forwarded to the
            filters and offline detectors.

    Returns:
        The path of the written CSV file (``"output.csv"``).

    Raises:
        gr.Error: If an offline detection method is selected without offline
            filtering, if online detection is selected without online
            filtering, or if no file was uploaded.
    """
    # Get the options from the checkbox group
    offline_filtering = 0 in detect_filter_opts
    lacourse = 1 in detect_filter_opts
    wamsley = 2 in detect_filter_opts
    online_filtering = 3 in detect_filter_opts
    online_detection = 4 in detect_filter_opts

    # Make sure the inputs make sense:
    if not offline_filtering and (lacourse or wamsley):
        raise gr.Error("You can't use the offline detection methods without offline filtering.")

    if not online_filtering and online_detection:
        raise gr.Error("You can't use the online detection without online filtering.")

    if xdf_file is None:
        raise gr.Error("Please upload a .xdf file.")

    freq = int(freq)

    # Read the xdf file to a numpy array
    print("Loading xdf file...")
    data_whole, columns = xdf2array(xdf_file.name, int(channel_num))

    # Do the offline filtering of the data
    if offline_filtering:
        print("Filtering offline...")
        offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
        data_whole = _append_column(
            data_whole, columns, offline_filtered_data, "offline_filtered_signal")

    # Do Wamsley's method (validation above guarantees the offline-filtered
    # column exists whenever this is enabled)
    if wamsley:
        print("Running Wamsley detection...")
        wamsley_data = offline_detect(
            "Wamsley",
            data_whole[:, columns.index("offline_filtered_signal")],
            data_whole[:, columns.index("time_stamps")],
            freq)
        data_whole = _append_column(data_whole, columns, wamsley_data, "wamsley_spindles")

    # Do Lacourse's method
    if lacourse:
        print("Running Lacourse detection...")
        lacourse_data = offline_detect(
            "Lacourse",
            data_whole[:, columns.index("offline_filtered_signal")],
            data_whole[:, columns.index("time_stamps")],
            freq)
        data_whole = _append_column(data_whole, columns, lacourse_data, "lacourse_spindles")

    # Online detection is only accepted together with online filtering (see the
    # validation above), so the whole online stage runs iff filtering is on.
    if online_filtering:
        # Get the data from the raw signal column
        data = data_whole[:, columns.index("raw_signal")]

        # Create the online filtering pipeline
        filter_pipeline = FilterPipeline(nb_channels=1, sampling_rate=freq)

        # Create the detector
        if online_detection:
            # channel is always 1 because we have only one channel
            detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)
            stimulator = OfflineSleepSpindleRealTimeStimulator()

        print("Running online filtering and detection...")

        points = []
        online_activations = []

        # Feed the samples one at a time, as a real-time run would
        for point in data:
            filtered_point = filter_pipeline.filter(np.array([point])).tolist()
            points.append(filtered_point[0])

            if online_detection:
                # Detect the spindles, then record 1/0 for stimulation/no stimulation
                result = detector.detect([filtered_point])
                online_activations.append(1 if stimulator.stimulate(result) else 0)

        data_whole = _append_column(
            data_whole, columns, np.array(points), "online_filtered_signal")

        if online_detection:
            data_whole = _append_column(
                data_whole, columns, np.array(online_activations), "online_stimulations")

    print("Saving output...")
    # Output the data to a csv file
    np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")

    print("Done!")
    return "output.csv"