Spaces:
Runtime error
Runtime error
bug fix
Browse files
app.py
CHANGED
@@ -38,12 +38,7 @@ def raw_preds_to_df(raw,quantiles = None):
|
|
38 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
39 |
return preds_df
|
40 |
|
41 |
-
def prepare_dataset(parameters, df, rain, temperature, datepicker):
|
42 |
-
rain_mapping = {
|
43 |
-
"Yes" : 1,
|
44 |
-
"No" : 0
|
45 |
-
}
|
46 |
-
|
47 |
if rain != "Default":
|
48 |
df["MTXWTH_Day_precip"] = rain_mapping[rain]
|
49 |
|
@@ -59,7 +54,7 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
|
|
59 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
60 |
|
61 |
@st.cache_data
|
62 |
-
def predict(_model,
|
63 |
out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
|
64 |
preds = raw_preds_to_df(out, quantiles = None)
|
65 |
|
@@ -113,6 +108,10 @@ def main():
|
|
113 |
## Initiate Data
|
114 |
parameters, df = load_data()
|
115 |
model = init_model()
|
|
|
|
|
|
|
|
|
116 |
|
117 |
# Start App
|
118 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
@@ -146,7 +145,7 @@ def main():
|
|
146 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
147 |
|
148 |
if st.button("Forecast Sales", type="primary"):
|
149 |
-
dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker)
|
150 |
preds = predict(model, dataloader)
|
151 |
generate_plot(df, preds)
|
152 |
|
|
|
38 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
39 |
return preds_df
|
40 |
|
41 |
+
def prepare_dataset(parameters, df, rain, temperature, datepicker, mapping):
|
|
|
|
|
|
|
|
|
|
|
42 |
if rain != "Default":
|
43 |
df["MTXWTH_Day_precip"] = rain_mapping[rain]
|
44 |
|
|
|
54 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
55 |
|
56 |
@st.cache_data
|
57 |
+
def predict(_model, _dataloader):
|
58 |
out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
|
59 |
preds = raw_preds_to_df(out, quantiles = None)
|
60 |
|
|
|
108 |
## Initiate Data
|
109 |
parameters, df = load_data()
|
110 |
model = init_model()
|
111 |
+
rain_mapping = {
|
112 |
+
"Yes" : 1,
|
113 |
+
"No" : 0
|
114 |
+
}
|
115 |
|
116 |
# Start App
|
117 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
|
|
145 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
146 |
|
147 |
if st.button("Forecast Sales", type="primary"):
|
148 |
+
dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, rain_mapping)
|
149 |
preds = predict(model, dataloader)
|
150 |
generate_plot(df, preds)
|
151 |
|