Milo Sobral commited on
Commit
7d40d1a
Β·
1 Parent(s): fcba0a9

Finished setting up the gradio demo

Browse files
.gitignore CHANGED
@@ -2,6 +2,9 @@
2
  .vscode/
3
  .idea/
4
 
 
 
 
5
  # Vagrant
6
  .vagrant/
7
 
 
2
  .vscode/
3
  .idea/
4
 
5
+ # Output from the demo
6
+ output.csv
7
+
8
  # Vagrant
9
  .vagrant/
10
 
portiloop/src/demo/demo.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import time
4
+ import numpy as np
5
+ import pandas as pd
6
+ from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
7
+ from portiloop.src.detection import SleepSpindleRealTimeDetector
8
+
9
+ from portiloop.src.stimulation import UpStateDelayer
10
+ plt.switch_backend('agg')
11
+ from portiloop.src.processing import FilterPipeline
12
+
13
+
14
+ def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
15
+
16
+ # Read the csv file to a numpy array
17
+ data_whole = np.loadtxt(csv_file.name, delimiter=',')
18
+
19
+ # Get the data from the selected channel
20
+ detect_channel = int(detect_channel)
21
+ freq = int(freq)
22
+ data = data_whole[:, detect_channel - 1]
23
+
24
+ # Create the detector and the stimulator
25
+ detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
26
+ stimulator = DemoSleepSpindleRealTimeStimulator()
27
+ if spindle_detection_mode != 'Fast':
28
+ delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
29
+ stimulator.add_delayer(delayer)
30
+
31
+ # Create the filtering pipeline
32
+ if filtering:
33
+ filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
34
+
35
+ # Plotting variables
36
+ points = []
37
+ activations = []
38
+ delayed_activations = []
39
+
40
+ # Go through the data
41
+ for index, point in enumerate(data):
42
+ # Step the delayer if exists
43
+ if spindle_detection_mode != 'Fast':
44
+ delayed = delayer.step(point)
45
+ if delayed:
46
+ delayed_activations.append(1)
47
+ else:
48
+ delayed_activations.append(0)
49
+
50
+ # Filter the data
51
+ if filtering:
52
+ filtered_point = filter.filter(np.array([point]))
53
+ else:
54
+ filtered_point = point
55
+
56
+ filtered_point = filtered_point.tolist()
57
+
58
+ # Detect the spindles
59
+ result = detector.detect([filtered_point])
60
+
61
+ # Stimulate if necessary
62
+ stim = stimulator.stimulate(result)
63
+ if stim:
64
+ activations.append(1)
65
+ else:
66
+ activations.append(0)
67
+
68
+ # Add data to plotting buffer
69
+ points.append(filtered_point[0])
70
+
71
+ # Plot the data
72
+ if index % (10 * freq) == 0:
73
+ plt.close()
74
+ fig = plt.figure(figsize=(20, 10))
75
+ plt.clf()
76
+ plt.plot(points[-10 * freq:], label="Data")
77
+ # Draw vertical lines for activations
78
+ for index in get_activations(activations[-10 * freq:]):
79
+ plt.axvline(x=index, color='r', label="Fast Stimulation")
80
+ if spindle_detection_mode != 'Fast':
81
+ for index in get_activations(delayed_activations[-10 * freq:]):
82
+ plt.axvline(x=index, color='g', label="Delayed Stimulation")
83
+ yield fig, None
84
+
85
+ # Put all points and activations back in numpy arrays
86
+ points = np.array(points)
87
+ activations = np.array(activations)
88
+ delayed_activations = np.array(delayed_activations)
89
+ # Concatenate with the original data
90
+ data_whole = np.concatenate((data_whole, points.reshape(-1, 1), activations.reshape(-1, 1), delayed_activations.reshape(-1, 1)), axis=1)
91
+ # Output the data to a csv file
92
+ np.savetxt('output.csv', data_whole, delimiter=',')
93
+
94
+ yield None, "output.csv"
95
+
96
+ # Function to return a list of all indexes where activations have happened
97
+ def get_activations(activations):
98
+ return [i for i, x in enumerate(activations) if x == 1]
99
+
100
+
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("Enter your csv file and click **Run Inference** to get the output.")
103
+
104
+ # Row containing all inputs:
105
+ with gr.Row():
106
+ # CSV file
107
+ csv_file = gr.UploadButton(label="CSV File", file_count="single")
108
+ # Filtering (Boolean)
109
+ filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
110
+ # Threshold value
111
+ threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
112
+ # Detection Channel
113
+ detect_column = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], value="1", label="Detection Column", interactive=True)
114
+ # Frequency
115
+ freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Frequency", interactive=True)
116
+ # Spindle Frequency
117
+ spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency", interactive=True)
118
+ # Spindle Detection Mode
119
+ spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
120
+ # Time to buffer
121
+ time_to_buffer = gr.Slider(0, 1, value=0, step=0.01, label="Time to Buffer", interactive=True)
122
+
123
+ # Output plot
124
+ output_plot = gr.Plot()
125
+ # Output file
126
+ output_array = gr.File(label="Output CSV File")
127
+
128
+ # Row containing all buttons:
129
+ with gr.Row():
130
+ # Run inference button
131
+ run_inference = gr.Button(value="Run Inference")
132
+ # Reset button
133
+ reset = gr.Button(value="Reset", variant="secondary")
134
+ run_inference.click(fn=do_treatment, inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer], outputs=[output_plot, output_array])
135
+
136
+ demo.queue()
137
+ demo.launch()
portiloop/src/demo/demo_stimulator.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from portiloop.src.stimulation import Stimulator
3
+
4
+
5
+ class DemoSleepSpindleRealTimeStimulator(Stimulator):
6
+ def __init__(self):
7
+ self.last_detected_ts = time.time()
8
+ self.wait_t = 0.4 # 400 ms
9
+
10
+ def stimulate(self, detection_signal):
11
+ stim = False
12
+ for sig in detection_signal:
13
+ # We detect a stimulation
14
+ if sig:
15
+ # Record time of stimulation
16
+ ts = time.time()
17
+
18
+ # Check if time since last stimulation is long enough
19
+ if ts - self.last_detected_ts > self.wait_t:
20
+ if self.delayer is not None:
21
+ # If we have a delayer, notify it
22
+ self.delayer.detected()
23
+ stim = True
24
+
25
+ self.last_detected_ts = ts
26
+ return stim
27
+
28
+ def add_delayer(self, delayer):
29
+ self.delayer = delayer
30
+ self.delayer.stimulate = lambda: True
portiloop/src/detection.py CHANGED
@@ -5,6 +5,8 @@ from portiloop.src import ADS
5
 
6
  if ADS:
7
  from pycoral.utils import edgetpu
 
 
8
  import numpy as np
9
 
10
 
@@ -53,7 +55,10 @@ class SleepSpindleRealTimeDetector(Detector):
53
 
54
  self.interpreters = []
55
  for i in range(self.num_models_parallel):
56
- self.interpreters.append(edgetpu.make_interpreter(model_path))
 
 
 
57
  self.interpreters[i].allocate_tensors()
58
  self.interpreter_counter = 0
59
 
@@ -76,6 +81,10 @@ class SleepSpindleRealTimeDetector(Detector):
76
  super().__init__(threshold)
77
 
78
  def detect(self, datapoints):
 
 
 
 
79
  res = []
80
  for inp in datapoints:
81
  result = self.add_datapoint(inp)
 
5
 
6
  if ADS:
7
  from pycoral.utils import edgetpu
8
+ else:
9
+ import tensorflow as tf
10
  import numpy as np
11
 
12
 
 
55
 
56
  self.interpreters = []
57
  for i in range(self.num_models_parallel):
58
+ if ADS:
59
+ self.interpreters.append(edgetpu.make_interpreter(model_path))
60
+ else:
61
+ self.interpreters.append(tf.lite.Interpreter(model_path=model_path))
62
  self.interpreters[i].allocate_tensors()
63
  self.interpreter_counter = 0
64
 
 
81
  super().__init__(threshold)
82
 
83
  def detect(self, datapoints):
84
+ """
85
+ Takes datapoints as input and outputs a detection signal.
86
+ datapoints is a list of lists of n channels: may contain several datapoints.
87
+ """
88
  res = []
89
  for inp in datapoints:
90
  result = self.add_datapoint(inp)
portiloop/{demo β†’ src/hardware/demo}/acquisition_demo.py RENAMED
File without changes
portiloop/{demo β†’ src/hardware/demo}/demo_net.py RENAMED
File without changes
portiloop/{demo β†’ src/hardware/demo}/led_demo.py RENAMED
File without changes
portiloop/src/stimulation.py CHANGED
@@ -142,7 +142,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
142
 
143
  def add_delayer(self, delayer):
144
  self.delayer = delayer
145
- self.delayer.stimulate = lambda x: self.send_stimulation("DELAY_STIM", True)
146
 
147
  # Class that delays stimulation to always stimulate peak or through
148
  class UpStateDelayer:
@@ -182,7 +182,7 @@ class UpStateDelayer:
182
  return False
183
  elif self.state == States.DELAYING:
184
  # Check if we are done delaying
185
- if time.time() - self.time_started >= self.time_to_wait():
186
  # Actually stimulate the patient after the delay
187
  if self.stimulate is not None:
188
  self.stimulate()
 
142
 
143
  def add_delayer(self, delayer):
144
  self.delayer = delayer
145
+ self.delayer.stimulate = lambda: self.send_stimulation("DELAY_STIM", True)
146
 
147
  # Class that delays stimulation to always stimulate peak or through
148
  class UpStateDelayer:
 
182
  return False
183
  elif self.state == States.DELAYING:
184
  # Check if we are done delaying
185
+ if time.time() - self.time_started >= self.time_to_wait:
186
  # Actually stimulate the patient after the delay
187
  if self.stimulate is not None:
188
  self.stimulate()