leuschnm commited on
Commit
a852418
·
1 Parent(s): 87d051f
Files changed (1) hide show
  1. app.py +22 -23
app.py CHANGED
@@ -109,14 +109,6 @@ def init_model():
109
  return model
110
 
111
  def main():
112
- ## Initiate Data
113
- parameters, df = load_data()
114
- model = init_model()
115
- rain_mapping = {
116
- "Yes" : 1,
117
- "No" : 0
118
- }
119
-
120
  # Start App
121
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
122
 
@@ -141,29 +133,36 @@ def main():
141
 
142
  ### Experiments
143
  """)
144
- rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
145
-
146
- temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
147
 
148
- datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
149
-
150
-
151
  try:
152
  # check if the key exists in session state
153
- _ = st.session_state.pressed
 
 
154
  except AttributeError:
155
  # otherwise set it to false
156
- st.session_state.pressed = False
 
 
157
 
158
- if st.button("Forecast Sales", type="primary") or st.session_state.pressed:
159
- dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, rain_mapping)
160
- preds = predict(model, dataloader, st.session_state.date)
 
 
 
 
161
 
162
- data_plot = adjust_data_for_plot(df.copy(), preds)
163
- fig, axs = generate_plot(data_plot)
164
- st.session_state.pressed = True
165
- st.pyplot(fig)
166
 
 
 
 
 
 
 
 
167
 
168
  if __name__ == '__main__':
169
  main()
 
109
  return model
110
 
111
  def main():
 
 
 
 
 
 
 
 
112
  # Start App
113
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
114
 
 
133
 
134
  ### Experiments
135
  """)
 
 
 
136
 
 
 
 
137
  try:
138
  # check if the key exists in session state
139
+ _ = st.session_state.rain
140
+ _ = st.session_state.temperature
141
+ _ = st.session_state.date
142
  except AttributeError:
143
  # otherwise set it to false
144
+ st.session_state.rain = 'Default'
145
+ st.session_state.temperature = 0.0
146
+ st.session_state.date = datetime.date(2022, 10, 24)
147
 
148
+ RAIN_MAPPING = {
149
+ "Yes" : 1,
150
+ "No" : 0
151
+ }
152
+
153
+ parameters, df = load_data()
154
+ model = init_model()
155
 
156
+ dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
157
+ preds = predict(model, dataloader, st.session_state.date)
 
 
158
 
159
+ data_plot = adjust_data_for_plot(df.copy(), preds)
160
+ fig, _ = generate_plot(data_plot)
161
+
162
+ st.pyplot(fig)
163
+ rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
164
+ temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
165
+ datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
166
 
167
  if __name__ == '__main__':
168
  main()