leuschnm commited on
Commit
83bf3df
·
1 Parent(s): fc970ed
Files changed (1) hide show
  1. app.py +4 -11
app.py CHANGED
@@ -57,20 +57,13 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
57
  def predict(_model, _dataloader, datepicker):
58
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
59
  preds = raw_preds_to_df(out, quantiles = None)
60
-
61
- def add_dates(group):
62
- group["date_imputed"] = [datepicker + datetime.timedelta(days=x) for x in range(30)]
63
- return group
64
-
65
- preds["date_imputed"] = preds.groupby("Group").pred.transform(add_dates)
66
-
67
- return preds[["date_imputed", "Group", "pred"]]
68
 
69
 
70
  def generate_plot(df, preds):
71
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
72
 
73
- df = pd.merge(df, pred, left_on='Date', right_on='date_imputed')
74
  # Plot scatter plots for each group
75
  axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey', marker='o')
76
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
@@ -145,10 +138,10 @@ def main():
145
 
146
  temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
147
 
148
- datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
149
 
150
  if st.button("Forecast Sales", type="primary"):
151
- dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
152
  preds = predict(model, dataloader, datepicker)
153
  generate_plot(df, preds)
154
 
 
57
  def predict(_model, _dataloader, datepicker):
58
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
59
  preds = raw_preds_to_df(out, quantiles = None)
60
+ return preds[["pred_idx", "Group", "pred"]
 
 
 
 
 
 
 
61
 
62
 
63
  def generate_plot(df, preds):
64
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
65
 
66
+ df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
67
  # Plot scatter plots for each group
68
  axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey', marker='o')
69
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
 
138
 
139
  temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
140
 
141
+ datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
142
 
143
  if st.button("Forecast Sales", type="primary"):
144
+ dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
145
  preds = predict(model, dataloader, datepicker)
146
  generate_plot(df, preds)
147