leuschnm commited on
Commit
f0fac46
·
1 Parent(s): 68dbb2b
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -54,8 +54,8 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  @st.cache_data
57
- def predict(_model, _dataloader):
58
- out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
59
  preds = raw_preds_to_df(out, quantiles = None)
60
 
61
  date_list = [datepicker + datetime.timedelta(days=x) for x in range(30)]
@@ -146,7 +146,7 @@ def main():
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
149
- preds = predict(model, dataloader)
150
  generate_plot(df, preds)
151
 
152
  if __name__ == '__main__':
 
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  @st.cache_data
57
+ def predict(_model, _dataloader, datepicker):
58
+ out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
59
  preds = raw_preds_to_df(out, quantiles = None)
60
 
61
  date_list = [datepicker + datetime.timedelta(days=x) for x in range(30)]
 
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
149
+ preds = predict(model, dataloader, datepicker)
150
  generate_plot(df, preds)
151
 
152
  if __name__ == '__main__':