leuschnm commited on
Commit
fd981f8
·
1 Parent(s): e8ee94c
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -70,7 +70,6 @@ def update_plot(df, preds, axs, fig):
70
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
71
  return fig, axs
72
 
73
- @st.cache_resource
74
  def generate_plot(df):
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
  df["sales"] = df["sales"].replace(0.0, np.nan)
@@ -135,18 +134,17 @@ def main():
135
 
136
  ### Experiments
137
  """)
138
- rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
139
 
140
- temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
141
 
142
- datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
143
 
144
  fig, axs = generate_plot(df.copy())
145
-
146
  st.pyplot(fig)
147
 
148
  if st.button("Forecast Sales", type="primary"):
149
- dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
150
  preds = predict(model, dataloader, datepicker)
151
  update_plot(df, preds, axs, fig)
152
 
 
70
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
71
  return fig, axs
72
 
 
73
  def generate_plot(df):
74
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
75
  df["sales"] = df["sales"].replace(0.0, np.nan)
 
134
 
135
  ### Experiments
136
  """)
137
+ rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
138
 
139
+ temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
140
 
141
+ datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
142
 
143
  fig, axs = generate_plot(df.copy())
 
144
  st.pyplot(fig)
145
 
146
  if st.button("Forecast Sales", type="primary"):
147
+ dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, rain_mapping)
148
  preds = predict(model, dataloader, datepicker)
149
  update_plot(df, preds, axs, fig)
150