leuschnm commited on
Commit
0ebac9a
·
1 Parent(s): 072ce3d
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -40,7 +40,7 @@ def raw_preds_to_df(raw,quantiles = None):
40
 
41
  def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
42
  if rain != "Default":
43
- df["MTXWTH_Day_precip"] = rain_mapping[rain]
44
 
45
  df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
46
  df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature
@@ -59,37 +59,34 @@ def predict(_model, _dataloader, datepicker):
59
  preds = raw_preds_to_df(out, quantiles = None)
60
  return preds[["pred_idx", "Group", "pred"]]
61
 
62
-
63
- def generate_plot(df, preds):
64
- fig, axs = plt.subplots(2, 2, figsize=(8, 6))
65
-
66
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
67
  df = df[~df["pred"].isna()]
68
  df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
 
 
 
 
 
 
 
 
 
 
 
69
  # Plot scatter plots for each group
70
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
71
- axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
72
- axs[0, 0].set_title('Article Group 1')
73
 
 
74
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='grey')
75
- axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
76
  axs[0, 1].set_title('Article Group 2')
77
 
78
  axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='grey')
79
- axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
80
  axs[1, 0].set_title('Article Group 3')
81
 
82
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
83
- axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
84
  axs[1, 1].set_title('Article Group 4')
85
-
86
- # Adjust spacing between subplots
87
  plt.tight_layout()
88
-
89
- #for ax in axs.flat:
90
- # ax.set_xlim(df['Date'].min(), df['Date'].max())
91
- # ax.set_ylim(df['sales'].min(), df['sales'].max())
92
-
93
  st.pyplot(fig)
94
 
95
  @st.cache_data
@@ -145,10 +142,12 @@ def main():
145
 
146
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
147
 
 
 
148
  if st.button("Forecast Sales", type="primary"):
149
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
150
  preds = predict(model, dataloader, datepicker)
151
- generate_plot(df, preds)
152
 
153
  if __name__ == '__main__':
154
  main()
 
40
 
41
  def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
42
  if rain != "Default":
43
+ df["MTXWTH_Day_precip"] = mapping[rain]
44
 
45
  df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
46
  df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature
 
59
  preds = raw_preds_to_df(out, quantiles = None)
60
  return preds[["pred_idx", "Group", "pred"]]
61
 
62
+ def update_plot(df, preds):
 
 
 
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
  df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
66
+
67
+ axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
+ axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
69
+ axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
70
+ axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
71
+ return st.pyplot(fig)
72
+
73
+ @st.cache_resource
74
+ def generate_plot():
75
+ fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
+
77
  # Plot scatter plots for each group
78
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
 
 
79
 
80
+ axs[0, 0].set_title('Article Group 1')
81
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='grey')
 
82
  axs[0, 1].set_title('Article Group 2')
83
 
84
  axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='grey')
 
85
  axs[1, 0].set_title('Article Group 3')
86
 
87
  axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
 
88
  axs[1, 1].set_title('Article Group 4')
 
 
89
  plt.tight_layout()
 
 
 
 
 
90
  st.pyplot(fig)
91
 
92
  @st.cache_data
 
142
 
143
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
144
 
145
+ generate_plot()
146
+
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
149
  preds = predict(model, dataloader, datepicker)
150
+ update_plot(df, preds)
151
 
152
  if __name__ == '__main__':
153
  main()