leuschnm commited on
Commit
7da7252
·
1 Parent(s): d4b9c62
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -38,6 +38,7 @@ def raw_preds_to_df(raw,quantiles = None):
38
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
  return preds_df
40
 
 
41
  def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
42
  if rain != "Default":
43
  df["MTXWTH_Day_precip"] = mapping[rain]
@@ -59,25 +60,21 @@ def predict(_model, _dataloader, datepicker):
59
  preds = raw_preds_to_df(out, quantiles = None)
60
  return preds[["pred_idx", "Group", "pred"]]
61
 
62
- def update_plot(df, preds, axs, fig):
 
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
- df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
 
66
 
67
- axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
- axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
69
- axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
70
- axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
71
- return fig, axs
72
 
73
- #@st.cache_resource(
74
  def generate_plot(df):
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
- df["sales"] = df["sales"].replace(0.0, np.nan)
77
  # Plot scatter plots for each group
78
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
79
-
80
  axs[0, 0].set_title('Article Group 1')
 
81
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='grey')
82
  axs[0, 1].set_title('Article Group 2')
83
 
@@ -86,6 +83,12 @@ def generate_plot(df):
86
 
87
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
88
  axs[1, 1].set_title('Article Group 4')
 
 
 
 
 
 
89
  plt.tight_layout()
90
  return fig, axs
91
 
@@ -141,14 +144,14 @@ def main():
141
 
142
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
143
 
144
- fig, axs = generate_plot(df.copy())
145
- st.pyplot(fig)
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, rain_mapping)
149
  preds = predict(model, dataloader, st.session_state.date)
150
- update_plot(df, preds, axs, fig)
151
 
 
 
 
152
 
153
 
154
  if __name__ == '__main__':
 
38
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
  return preds_df
40
 
41
+ @st.cache_data
42
  def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
43
  if rain != "Default":
44
  df["MTXWTH_Day_precip"] = mapping[rain]
 
60
  preds = raw_preds_to_df(out, quantiles = None)
61
  return preds[["pred_idx", "Group", "pred"]]
62
 
63
+ @st.cache_data
64
+ def adjust_data_for_plot(df, preds):
65
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
66
  df = df[~df["pred"].isna()]
67
+ df["sales"] = df["sales"].replace(0.0, np.nan)
68
+ return df
69
 
 
 
 
 
 
70
 
 
71
  def generate_plot(df):
72
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
73
+
74
  # Plot scatter plots for each group
75
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
 
76
  axs[0, 0].set_title('Article Group 1')
77
+
78
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='grey')
79
  axs[0, 1].set_title('Article Group 2')
80
 
 
83
 
84
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
85
  axs[1, 1].set_title('Article Group 4')
86
+
87
+ axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
88
+ axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
89
+ axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
90
+ axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
91
+
92
  plt.tight_layout()
93
  return fig, axs
94
 
 
144
 
145
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
146
 
 
 
147
 
148
  if st.button("Forecast Sales", type="primary"):
149
  dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, rain_mapping)
150
  preds = predict(model, dataloader, st.session_state.date)
 
151
 
152
+ data_plot = adjust_data_for_plot(df.copy(), preds)
153
+ fig, axs = generate_plot(df.copy())
154
+ st.pyplot(fig)
155
 
156
 
157
  if __name__ == '__main__':