leuschnm commited on
Commit
8c67f44
·
1 Parent(s): cb667ba
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -49,27 +49,23 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
49
  upperbound = datepicker + datetime.timedelta(days = 30)
50
 
51
  df = df.loc[(df["Date"].dt.date>lowerbound) & (df["Date"].dt.date<=upperbound)]
 
52
 
53
  df = TimeSeriesDataSet.from_parameters(parameters, df)
54
- return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
- def predict(model, dataloader):
57
- return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
-
59
- #on_click=None,
60
-
61
- # %%
62
- #preds = raw_preds_to_df(out, quantiles = None)
63
 
64
- #preds = preds.merge(data_selected[['time_idx','Group','Branch','sales','weight','Date','MTXWTH_Day_precip','MTXWTH_Temp_max','MTXWTH_Temp_min']],how='left',left_on=['pred_idx','Group','Branch'],right_on=['time_idx','Group','Branch'])
65
- #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
66
- #preds.drop(columns=['time_idx_y'],inplace=True)
67
 
68
- def generate_plot(df): #, predictions):
69
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
70
 
71
  # Plot scatter plots for each group
72
- axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
 
73
  axs[0, 0].set_title('Article Group 1')
74
 
75
  axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
@@ -137,8 +133,9 @@ def main():
137
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
138
 
139
  if st.button("Forecast Sales", type="primary"):
140
- converted_data = prepare_dataset(parameters, df, rain, temperature, datepicker)
141
- generate_plot(df)
 
142
 
143
  if __name__ == '__main__':
144
  main()
 
49
  upperbound = datepicker + datetime.timedelta(days = 30)
50
 
51
  df = df.loc[(df["Date"].dt.date>lowerbound) & (df["Date"].dt.date<=upperbound)]
52
+ dates = df["Date"]
53
 
54
  df = TimeSeriesDataSet.from_parameters(parameters, df)
55
+ return df.to_dataloader(train=False, batch_size=256,num_workers = 0), dates
56
 
57
+ def predict(model, dataloader):
58
+ out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
59
+ preds = raw_preds_to_df(out, quantiles = None)
60
+ return preds[["Group", "pred"]]
 
 
 
61
 
 
 
 
62
 
63
+ def generate_plot(df, dates, preds)
64
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
65
 
66
  # Plot scatter plots for each group
67
+ axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey', marker='o')
68
+ axs[0, 0].plot(dates, preds.loc[preds['Group'] == '4', 'pred'], color = 'red')
69
  axs[0, 0].set_title('Article Group 1')
70
 
71
  axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
 
133
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
134
 
135
  if st.button("Forecast Sales", type="primary"):
136
+ dataloader, dates = prepare_dataset(parameters, df, rain, temperature, datepicker)
137
+ preds = predict(model, dataloader)
138
+ generate_plot(df, dates, preds)
139
 
140
  if __name__ == '__main__':
141
  main()