leuschnm commited on
Commit
68dbb2b
·
1 Parent(s): 49140ff
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -38,12 +38,7 @@ def raw_preds_to_df(raw,quantiles = None):
38
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
  return preds_df
40
 
41
- def prepare_dataset(parameters, df, rain, temperature, datepicker):
42
- rain_mapping = {
43
- "Yes" : 1,
44
- "No" : 0
45
- }
46
-
47
  if rain != "Default":
48
  df["MTXWTH_Day_precip"] = rain_mapping[rain]
49
 
@@ -59,7 +54,7 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
59
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
60
 
61
  @st.cache_data
62
- def predict(_model, dataloader):
63
  out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
64
  preds = raw_preds_to_df(out, quantiles = None)
65
 
@@ -113,6 +108,10 @@ def main():
113
  ## Initiate Data
114
  parameters, df = load_data()
115
  model = init_model()
 
 
 
 
116
 
117
  # Start App
118
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
@@ -146,7 +145,7 @@ def main():
146
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
147
 
148
  if st.button("Forecast Sales", type="primary"):
149
- dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker)
150
  preds = predict(model, dataloader)
151
  generate_plot(df, preds)
152
 
 
38
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
  return preds_df
40
 
41
+ def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
 
 
 
 
 
42
  if rain != "Default":
43
  df["MTXWTH_Day_precip"] = rain_mapping[rain]
44
 
 
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  @st.cache_data
57
+ def predict(_model, _dataloader):
58
  out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
59
  preds = raw_preds_to_df(out, quantiles = None)
60
 
 
108
  ## Initiate Data
109
  parameters, df = load_data()
110
  model = init_model()
111
+ rain_mapping = {
112
+ "Yes" : 1,
113
+ "No" : 0
114
+ }
115
 
116
  # Start App
117
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
 
145
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
+ dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
149
  preds = predict(model, dataloader)
150
  generate_plot(df, preds)
151