Spaces:
Runtime error
Runtime error
image
Browse files- app.py +17 -8
- data/image.png +0 -0
app.py
CHANGED
@@ -16,6 +16,8 @@ from pytorch_forecasting import (
|
|
16 |
TemporalFusionTransformer,
|
17 |
)
|
18 |
|
|
|
|
|
19 |
## Functions
|
20 |
def raw_preds_to_df(raw,quantiles = None):
|
21 |
"""
|
@@ -38,7 +40,6 @@ def raw_preds_to_df(raw,quantiles = None):
|
|
38 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
39 |
return preds_df
|
40 |
|
41 |
-
#@st.cache_data
|
42 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
43 |
if rain != "Default":
|
44 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
@@ -54,13 +55,11 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
54 |
df = TimeSeriesDataSet.from_parameters(_parameters, df)
|
55 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
56 |
|
57 |
-
#@st.cache_data
|
58 |
def predict(_model, _dataloader, datepicker):
|
59 |
out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
|
60 |
preds = raw_preds_to_df(out, quantiles = None)
|
61 |
return preds[["pred_idx", "Group", "pred"]]
|
62 |
|
63 |
-
#@st.cache_data
|
64 |
def adjust_data_for_plot(df, preds):
|
65 |
df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
|
66 |
df = df[~df["pred"].isna()]
|
@@ -111,7 +110,11 @@ def init_model():
|
|
111 |
def main():
|
112 |
# Start App
|
113 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
114 |
-
|
|
|
|
|
|
|
|
|
115 |
st.markdown(body = """
|
116 |
### Abstract
|
117 |
Multi-horizon forecasting often contains a complex mix of inputs – including
|
@@ -132,6 +135,10 @@ def main():
|
|
132 |
interpretability use cases of TFT.
|
133 |
|
134 |
### Experiments
|
|
|
|
|
|
|
|
|
135 |
""")
|
136 |
|
137 |
try:
|
@@ -158,12 +165,14 @@ def main():
|
|
158 |
|
159 |
data_plot = adjust_data_for_plot(df.copy(), preds)
|
160 |
fig, _ = generate_plot(data_plot)
|
161 |
-
|
162 |
-
st.pyplot(fig)
|
163 |
-
rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
|
164 |
-
temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
|
165 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
|
166 |
|
|
|
|
|
|
|
|
|
|
|
167 |
if __name__ == '__main__':
|
168 |
main()
|
169 |
|
|
|
16 |
TemporalFusionTransformer,
|
17 |
)
|
18 |
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
## Functions
|
22 |
def raw_preds_to_df(raw,quantiles = None):
|
23 |
"""
|
|
|
40 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
41 |
return preds_df
|
42 |
|
|
|
43 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
44 |
if rain != "Default":
|
45 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
|
|
55 |
df = TimeSeriesDataSet.from_parameters(_parameters, df)
|
56 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
57 |
|
|
|
58 |
def predict(_model, _dataloader, datepicker):
|
59 |
out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
|
60 |
preds = raw_preds_to_df(out, quantiles = None)
|
61 |
return preds[["pred_idx", "Group", "pred"]]
|
62 |
|
|
|
63 |
def adjust_data_for_plot(df, preds):
|
64 |
df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
|
65 |
df = df[~df["pred"].isna()]
|
|
|
110 |
def main():
|
111 |
# Start App
|
112 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
113 |
+
|
114 |
+
|
115 |
+
image = Image.open('data/image.png')
|
116 |
+
|
117 |
+
st.image(image, caption='Coding.Waterkant Festival for AI')
|
118 |
st.markdown(body = """
|
119 |
### Abstract
|
120 |
Multi-horizon forecasting often contains a complex mix of inputs – including
|
|
|
135 |
interpretability use cases of TFT.
|
136 |
|
137 |
### Experiments
|
138 |
+
We implemented TFT for sales multi-horizon sales forecast during Coding.Waterkant.
|
139 |
+
Please try our implementation and adjust some of the training data.
|
140 |
+
|
141 |
+
Adjustments to the model and extention with Quantile forecasts are coming soon ;)
|
142 |
""")
|
143 |
|
144 |
try:
|
|
|
165 |
|
166 |
data_plot = adjust_data_for_plot(df.copy(), preds)
|
167 |
fig, _ = generate_plot(data_plot)
|
168 |
+
|
|
|
|
|
|
|
169 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
|
170 |
|
171 |
+
st.pyplot(fig)
|
172 |
+
|
173 |
+
temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
|
174 |
+
rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
|
175 |
+
|
176 |
if __name__ == '__main__':
|
177 |
main()
|
178 |
|
data/image.png
ADDED
![]() |