Spaces:
Runtime error
Runtime error
bugfix
Browse files- .ipynb_checkpoints/app-checkpoint.py +116 -0
- app.py +1 -1
.ipynb_checkpoints/app-checkpoint.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Imports
|
2 |
+
import pickle
|
3 |
+
import warnings
|
4 |
+
import streamlit as st
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import datetime
|
11 |
+
|
12 |
+
# import torch
|
13 |
+
from torch.distributions import Normal
|
14 |
+
from pytorch_forecasting import (
|
15 |
+
TimeSeriesDataSet,
|
16 |
+
TemporalFusionTransformer,
|
17 |
+
)
|
18 |
+
|
19 |
+
## Functions
|
20 |
+
def raw_preds_to_df(raw,quantiles = None):
|
21 |
+
"""
|
22 |
+
raw is output of model.predict with return_index=True
|
23 |
+
quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
|
24 |
+
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
25 |
+
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
26 |
+
"""
|
27 |
+
index = raw[2]
|
28 |
+
preds = raw[0].prediction
|
29 |
+
dec_len = preds.shape[1]
|
30 |
+
n_quantiles = preds.shape[-1]
|
31 |
+
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
32 |
+
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
33 |
+
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
34 |
+
preds_df = preds_df.assign(pred=preds.flatten().cpu().numpy())
|
35 |
+
if quantiles is not None:
|
36 |
+
preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})
|
37 |
+
|
38 |
+
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
39 |
+
return preds_df
|
40 |
+
|
41 |
+
def prepare_dataset(parameters, df, rain, temperature, datepicker):
|
42 |
+
if rain != "Default":
|
43 |
+
df["MTXWTH_Day_precip"] = rain_mapping[rain]
|
44 |
+
|
45 |
+
df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
|
46 |
+
df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature
|
47 |
+
|
48 |
+
lowerbound = datepicker - datetime.timedelta(days = 35)
|
49 |
+
upperbound = datepicker + datetime.timedelta(days = 30)
|
50 |
+
|
51 |
+
df = df.loc[(df["Date"]>lowerbound) & (df["Date"]<=upperbound)]
|
52 |
+
|
53 |
+
df = TimeSeriesDataSet.from_parameters(parameters, df)
|
54 |
+
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
55 |
+
|
56 |
+
def predict(model, dataloader):
|
57 |
+
return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
|
58 |
+
|
59 |
+
## Initiate Data
|
60 |
+
with open('data/parameters.pkl', 'rb') as f:
|
61 |
+
parameters = pickle.load(f)
|
62 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt')
|
63 |
+
|
64 |
+
df = pd.read_pickle('data/test_data.pkl')
|
65 |
+
df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
|
66 |
+
|
67 |
+
rain_mapping = {
|
68 |
+
"Yes" : 1,
|
69 |
+
"No" : 0
|
70 |
+
}
|
71 |
+
|
72 |
+
# Start App
|
73 |
+
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
74 |
+
|
75 |
+
st.markdown(body = """
|
76 |
+
### Abstract
|
77 |
+
Multi-horizon forecasting often contains a complex mix of inputs – including
|
78 |
+
static (i.e. time-invariant) covariates, known future inputs, and other exogenous
|
79 |
+
time series that are only observed in the past – without any prior information
|
80 |
+
on how they interact with the target. Several deep learning methods have been
|
81 |
+
proposed, but they are typically ‘black-box’ models which do not shed light on
|
82 |
+
how they use the full range of inputs present in practical scenarios. In this pa-
|
83 |
+
per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
|
84 |
+
based architecture which combines high-performance multi-horizon forecasting
|
85 |
+
with interpretable insights into temporal dynamics. To learn temporal rela-
|
86 |
+
tionships at different scales, TFT uses recurrent layers for local processing and
|
87 |
+
interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
|
88 |
+
cialized components to select relevant features and a series of gating layers to
|
89 |
+
suppress unnecessary components, enabling high performance in a wide range of
|
90 |
+
scenarios. On a variety of real-world datasets, we demonstrate significant per-
|
91 |
+
formance improvements over existing benchmarks, and showcase three practical
|
92 |
+
interpretability use cases of TFT.
|
93 |
+
""")
|
94 |
+
|
95 |
+
rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
|
96 |
+
|
97 |
+
temperature = st.slider('Change in Temperature', min_value=-10, max_value=+10, value=0, step=0.25)
|
98 |
+
|
99 |
+
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
100 |
+
|
101 |
+
arr = np.random.normal(1, 1, size=100)
|
102 |
+
fig, ax = plt.subplots()
|
103 |
+
ax.hist(arr, bins=20)
|
104 |
+
|
105 |
+
st.pyplot(fig)
|
106 |
+
|
107 |
+
st.button("Forecast Sales", type="primary") #on_click=None,
|
108 |
+
|
109 |
+
# %%
|
110 |
+
preds = raw_preds_to_df(out, quantiles = None)
|
111 |
+
|
112 |
+
preds = preds.merge(data_selected[['time_idx','Group','Branch','sales','weight','Date','MTXWTH_Day_precip','MTXWTH_Temp_max','MTXWTH_Temp_min']],how='left',left_on=['pred_idx','Group','Branch'],right_on=['time_idx','Group','Branch'])
|
113 |
+
preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
|
114 |
+
preds.drop(columns=['time_idx_y'],inplace=True)
|
115 |
+
|
116 |
+
|
app.py
CHANGED
@@ -66,7 +66,7 @@ df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
|
|
66 |
|
67 |
rain_mapping = {
|
68 |
"Yes" : 1,
|
69 |
-
"No" :
|
70 |
}
|
71 |
|
72 |
# Start App
|
|
|
66 |
|
67 |
rain_mapping = {
|
68 |
"Yes" : 1,
|
69 |
+
"No" : 0
|
70 |
}
|
71 |
|
72 |
# Start App
|