File size: 5,437 Bytes
7d40d1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import matplotlib.pyplot as plt
import time
import numpy as np
import pandas as pd
from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
from portiloop.src.detection import SleepSpindleRealTimeDetector

from portiloop.src.stimulation import UpStateDelayer
plt.switch_backend('agg')
from portiloop.src.processing import FilterPipeline


def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
    """Replay one column of a CSV recording through the spindle detector and stimulator.

    This is a generator used as a streaming Gradio callback.

    Args:
        csv_file: Uploaded file object; only its ``.name`` attribute (a path) is read.
        filtering: If truthy, every sample goes through a single-channel
            ``FilterPipeline`` before detection.
        threshold: Detection threshold forwarded to ``SleepSpindleRealTimeDetector``.
        detect_channel: 1-based column index (str or int) of the signal to process.
        freq: Sampling rate (str or int), used for the filter, the delayer and
            the plotting window.
        spindle_freq: Forwarded to ``UpStateDelayer``.
        spindle_detection_mode: ``'Fast'`` uses no delayer. Any other value adds
            an ``UpStateDelayer`` whose third positional argument is
            ``spindle_detection_mode == 'Peak'``.
        time_to_buffer: Forwarded to ``UpStateDelayer``.

    Yields:
        ``(figure, None)`` every ``10 * freq`` samples (starting at sample 0).
        The figure shows the last ``10 * freq`` processed samples, with red
        vertical lines at stimulations and, when a delayer is used, green lines
        at delayed stimulations.

        ``(None, "output.csv")`` once the augmented CSV has been written.
        The output holds the original columns followed by three extra columns:
        the processed signal, the stimulation flags (0/1) and the delayed
        stimulation flags (0/1, all zeros in 'Fast' mode).

    Raises:
        ValueError: If ``detect_channel`` is not a valid column of the file.
    """
    # ndmin=2 keeps single-column files 2-D so column indexing below works.
    data_whole = np.loadtxt(csv_file.name, delimiter=',', ndmin=2)

    # Dropdown values arrive as strings.
    detect_channel = int(detect_channel)
    freq = int(freq)
    nb_columns = data_whole.shape[1]
    if not 1 <= detect_channel <= nb_columns:
        raise ValueError(
            f"Detection column {detect_channel} is out of range: the CSV file has {nb_columns} column(s)."
        )
    data = data_whole[:, detect_channel - 1]

    use_delayer = spindle_detection_mode != 'Fast'

    # Create the detector and the stimulator.
    detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1)  # always 1: only one channel is fed in
    stimulator = DemoSleepSpindleRealTimeStimulator()
    if use_delayer:
        delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
        stimulator.add_delayer(delayer)

    if filtering:
        filter_pipeline = FilterPipeline(nb_channels=1, sampling_rate=freq)

    window = 10 * freq  # number of samples per plot refresh / samples shown
    points = []
    activations = []
    delayed_activations = []
    fig = None

    for sample_index, point in enumerate(data):
        # The delayer is stepped with the raw (unfiltered) sample.
        if use_delayer:
            delayed_activations.append(1 if delayer.step(point) else 0)
        else:
            # No delayer: record "no delayed stimulation" so every output column has one row per sample.
            delayed_activations.append(0)

        # Both branches yield an array, so .tolist() always returns a list (a bare
        # np.float64 would turn into a float and break the [0] indexing below).
        if filtering:
            filtered_point = filter_pipeline.filter(np.array([point]))
        else:
            filtered_point = np.array([point])
        filtered_point = filtered_point.tolist()

        # Detect spindles and stimulate if necessary.
        result = detector.detect([filtered_point])
        activations.append(1 if stimulator.stimulate(result) else 0)

        points.append(filtered_point[0])

        if sample_index % window == 0:
            # Close only the figure this call created (the previous, already-yielded one).
            if fig is not None:
                plt.close(fig)
            fig, ax = plt.subplots(figsize=(20, 10))
            ax.plot(points[-window:], label="Data")
            for stim_index in get_activations(activations[-window:]):
                ax.axvline(x=stim_index, color='r', label="Fast Stimulation")
            if use_delayer:
                for stim_index in get_activations(delayed_activations[-window:]):
                    ax.axvline(x=stim_index, color='g', label="Delayed Stimulation")
            yield fig, None

    # Release the last plotted figure; it has already been sent to the UI.
    if fig is not None:
        plt.close(fig)

    # Append processed signal and stimulation flags as new columns, then save.
    points = np.array(points)
    activations = np.array(activations)
    delayed_activations = np.array(delayed_activations)
    data_whole = np.concatenate(
        (data_whole, points.reshape(-1, 1), activations.reshape(-1, 1), delayed_activations.reshape(-1, 1)),
        axis=1,
    )
    np.savetxt('output.csv', data_whole, delimiter=',')

    yield None, "output.csv"
        
def get_activations(activations):
    """Return the positions in ``activations`` whose value equals 1.

    Args:
        activations: Sequence of activation flags (e.g. 0/1 values).

    Returns:
        A list of the indexes, in increasing order, where the flag is ``== 1``.
    """
    positions = []
    for position, value in enumerate(activations):
        if value == 1:
            positions.append(position)
    return positions
    

# Gradio UI: upload a CSV, choose the detection parameters and stream the results.
with gr.Blocks() as demo:
    gr.Markdown("Enter your csv file and click **Run Inference** to get the output.")

    # Row containing all inputs.
    with gr.Row():
        # CSV file to process.
        csv_file = gr.UploadButton(label="CSV File", file_count="single")
        # Whether to run the samples through the filter pipeline.
        filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
        # Detection threshold.
        threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
        # 1-based index of the CSV column used for detection ("1" .. "10").
        detect_column = gr.Dropdown(choices=[str(i) for i in range(1, 11)], value="1", label="Detection Column", interactive=True)
        # Sampling rate of the recording.
        freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Frequency", interactive=True)
        # Spindle frequency passed to the up-state delayer.
        spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency", interactive=True)
        # 'Fast' stimulates without a delayer; 'Peak' / 'Valley' add one.
        spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
        # Time to buffer passed to the up-state delayer.
        time_to_buffer = gr.Slider(0, 1, value=0, step=0.01, label="Time to Buffer", interactive=True)

    # Outputs: live plot of the signal and the augmented CSV file.
    output_plot = gr.Plot()
    output_array = gr.File(label="Output CSV File")

    # Row containing all buttons.
    with gr.Row():
        run_inference = gr.Button(value="Run Inference")
        reset = gr.Button(value="Reset", variant="secondary")

    run_inference.click(
        fn=do_treatment,
        inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer],
        outputs=[output_plot, output_array],
    )
    # Reset clears the plot and the output file.
    reset.click(fn=lambda: (None, None), inputs=None, outputs=[output_plot, output_array])

# Queueing is required for generator (streaming) callbacks such as do_treatment.
demo.queue()
demo.launch()