Spaces:
Runtime error
Runtime error
bug fix
Browse files
app.py
CHANGED
@@ -54,8 +54,8 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
|
|
54 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
55 |
|
56 |
@st.cache_data
|
57 |
-
def predict(_model, _dataloader):
|
58 |
-
out =
|
59 |
preds = raw_preds_to_df(out, quantiles = None)
|
60 |
|
61 |
date_list = [datepicker + datetime.timedelta(days=x) for x in range(30)]
|
@@ -146,7 +146,7 @@ def main():
|
|
146 |
|
147 |
if st.button("Forecast Sales", type="primary"):
|
148 |
dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
|
149 |
-
preds = predict(model, dataloader)
|
150 |
generate_plot(df, preds)
|
151 |
|
152 |
if __name__ == '__main__':
|
|
|
54 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
55 |
|
56 |
@st.cache_data
|
57 |
+
def predict(_model, _dataloader, datepicker):
|
58 |
+
out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
|
59 |
preds = raw_preds_to_df(out, quantiles = None)
|
60 |
|
61 |
date_list = [datepicker + datetime.timedelta(days=x) for x in range(30)]
|
|
|
146 |
|
147 |
if st.button("Forecast Sales", type="primary"):
|
148 |
dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
|
149 |
+
preds = predict(model, dataloader, datepicker)
|
150 |
generate_plot(df, preds)
|
151 |
|
152 |
if __name__ == '__main__':
|