leuschnm commited on
Commit
90de3d0
·
1 Parent(s): ac3bcbd
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -59,10 +59,10 @@ def predict(_model, _dataloader, datepicker):
59
  preds = raw_preds_to_df(out, quantiles = None)
60
  return preds[["pred_idx", "Group", "pred"]]
61
 
62
- def update_plot(df, preds):
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
- #df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
66
 
67
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
@@ -73,7 +73,7 @@ def update_plot(df, preds):
73
  @st.cache_resource
74
  def generate_plot(df):
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
-
77
  # Plot scatter plots for each group
78
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
79
 
@@ -87,7 +87,7 @@ def generate_plot(df):
87
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
88
  axs[1, 1].set_title('Article Group 4')
89
  plt.tight_layout()
90
- st.pyplot(fig)
91
 
92
  @st.cache_data
93
  def load_data():
@@ -142,12 +142,14 @@ def main():
142
 
143
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
144
 
145
- generate_plot(df)
 
 
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
149
  preds = predict(model, dataloader, datepicker)
150
- update_plot(df, preds)
151
 
152
  if __name__ == '__main__':
153
  main()
 
59
  preds = raw_preds_to_df(out, quantiles = None)
60
  return preds[["pred_idx", "Group", "pred"]]
61
 
62
+ def update_plot(df, preds, axs):
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
+ df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
66
 
67
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
 
73
  @st.cache_resource
74
  def generate_plot(df):
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
+ df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
77
  # Plot scatter plots for each group
78
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
79
 
 
87
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
88
  axs[1, 1].set_title('Article Group 4')
89
  plt.tight_layout()
90
+ return fig, axs
91
 
92
  @st.cache_data
93
  def load_data():
 
142
 
143
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
144
 
145
+ fig, axs = generate_plot(df)
146
+
147
+ st.pyplot(fig)
148
 
149
  if st.button("Forecast Sales", type="primary"):
150
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
151
  preds = predict(model, dataloader, datepicker)
152
+ update_plot(df, preds, axs)
153
 
154
  if __name__ == '__main__':
155
  main()