Spaces:
Runtime error
Runtime error
File size: 7,610 Bytes
c850717 962d9e1 c850717 add096e c850717 769a99c c850717 be7c874 c850717 769a99c c850717 769a99c c850717 769a99c c850717 be7c874 4c1e191 c850717 0ebac9a c850717 bb5ae00 c850717 d23039a dc44445 c850717 2c75212 9f8b254 c850717 7da7252 83bf3df 072ce3d 7da7252 0ebac9a ac3bcbd 0ebac9a 7da7252 89c3cca a4f898a ed03da7 0ebac9a 87d051f a6e9087 a4f898a ed03da7 89c3cca 87d051f 89c3cca a4f898a ed03da7 89c3cca 87d051f 89c3cca a4f898a 7da7252 ed03da7 87d051f 17832c5 a6e9087 90de3d0 89c3cca f2ea1f4 be7c874 3a5b465 89c3cca f2ea1f4 be7c874 f2ea1f4 3a5b465 add096e 6bd273b 6a5173b 6bd273b 6a5173b 6bd273b 552e2fb 6bd273b 769a99c 6bd273b 3a5b465 f28e66f 50d06b3 33d942e 3a5b465 c850717 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
## Imports
import pickle
import warnings
import streamlit as st
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
from torch.distributions import Normal
from pytorch_forecasting import (
TimeSeriesDataSet,
TemporalFusionTransformer,
)
from PIL import Image
## Functions
def raw_preds_to_df(raw, quantiles = None):
    """
    Flatten raw model predictions into a long-format DataFrame.

    raw is output of model.predict with return_index=True: raw[0] must expose a
    `prediction` tensor and raw[2] is the index DataFrame (one row per sample,
    containing a `time_idx` column).
    NOTE(review): the tensor is assumed to be laid out as
    (n_samples, decoder_length, n_quantiles) - confirm against the model output.

    quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles;
    quantile positions without an entry in that list end up as NaN in `q`.

    In the output:
      - the index columns are repeated for every (horizon, quantile) pair,
      - time_idx is the first prediction time index (one step after knowledge cutoff),
      - h is the 1-based forecast horizon,
      - q is the quantile position (or the quantile value when `quantiles` is given),
      - pred is the predicted value,
      - pred_idx is the index of the predicted date i.e. time_idx + h - 1.
    """
    index = raw[2]
    preds = raw[0].prediction
    dec_len = preds.shape[1]
    n_quantiles = preds.shape[-1]
    n_samples = len(index)

    # Repeat the index rows by position instead of going through `index.values`:
    # the latter collapses mixed-dtype columns (e.g. int time_idx + str group)
    # into `object`, which would leak into `time_idx` and `pred_idx`.
    positions = np.repeat(np.arange(n_samples), dec_len * n_quantiles)
    preds_df = index.iloc[positions].reset_index(drop=True)

    preds_df = preds_df.assign(
        h=np.tile(np.repeat(np.arange(1, 1 + dec_len), n_quantiles), n_samples),
        q=np.tile(np.arange(n_quantiles), n_samples * dec_len),
        # detach() so tensors that still track gradients can be converted to numpy
        pred=preds.detach().flatten().cpu().numpy(),
    )
    if quantiles is not None:
        preds_df['q'] = preds_df['q'].map({i: q for i, q in enumerate(quantiles)})
    preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
    return preds_df
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
    """
    Build the inference dataloader for a what-if scenario around `datepicker`.

    _parameters: dataset parameters passed to TimeSeriesDataSet.from_parameters
    df: source data with a datetime `Date` column and the weather columns
        `MTXWTH_Day_precip`, `MTXWTH_Temp_min`, `MTXWTH_Temp_max`
    rain: "Default" keeps the recorded precipitation flag, any other value is
        looked up in `mapping` and written to every row of the window
    temperature: offset added to both the min and max temperature columns
    datepicker: date marking the start of the forecast
    mapping: dict translating the `rain` choice to the precipitation value

    Returns a non-training dataloader (batch size 256, no worker processes).
    The caller's `df` is left untouched.
    """
    # Keep rows strictly after (datepicker - 35 days) and up to and including
    # (datepicker + 30 days).
    lowerbound = datepicker - datetime.timedelta(days = 35)
    upperbound = datepicker + datetime.timedelta(days = 30)
    dates = df["Date"].dt.date
    # Filter first and work on an explicit copy, so the scenario adjustments
    # neither mutate the caller's frame nor touch rows that are discarded anyway.
    window = df.loc[(dates > lowerbound) & (dates <= upperbound)].copy()

    if rain != "Default":
        window["MTXWTH_Day_precip"] = mapping[rain]
    window["MTXWTH_Temp_min"] = window["MTXWTH_Temp_min"] + temperature
    window["MTXWTH_Temp_max"] = window["MTXWTH_Temp_max"] + temperature

    dataset = TimeSeriesDataSet.from_parameters(_parameters, window)
    return dataset.to_dataloader(train=False, batch_size=256, num_workers = 0)
def predict(model, dataloader):
    """
    Run the model on `dataloader` in raw mode and return the predictions as a
    long DataFrame restricted to the columns `pred_idx`, `Group` and `pred`.
    """
    raw_output = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
    long_preds = raw_preds_to_df(raw_output, quantiles = None)
    wanted_columns = ["pred_idx", "Group", "pred"]
    return long_preds[wanted_columns]
def adjust_data_for_plot(df, preds):
    """
    Join predictions onto the observed data and prepare it for plotting.

    df: observed data with `time_idx`, `Group` and `sales` columns
    preds: predictions with `pred_idx`, `Group` and `pred` columns

    Rows of `df` are matched to predictions on (time_idx == pred_idx, Group);
    rows without a prediction are dropped. Sales of exactly 0.0 are replaced by
    NaN in the returned frame. Neither input is modified.
    """
    merged = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
    with_preds = merged.loc[merged["pred"].notna()]
    # assign() returns a new frame; writing the column into the filtered slice
    # directly would trigger pandas' SettingWithCopyWarning.
    return with_preds.assign(sales=with_preds["sales"].replace(0.0, np.nan))
def generate_plot(df):
    """
    Draw a 2x2 grid comparing observed sales (grey) with predictions (red),
    one panel per article group.

    df: frame with `Group`, `Date`, `sales` and `pred` columns

    Returns the matplotlib figure and the 2x2 array of axes.
    """
    # (group id in the data, panel title), listed in row-major panel order
    panels = (
        ('4', 'Article Group 1'),
        ('7', 'Article Group 2'),
        ('1', 'Article Group 3'),
        ('6', 'Article Group 4'),
    )
    fig, axs = plt.subplots(2, 2, figsize=(8, 6))
    for ax, (group, title) in zip(axs.flat, panels):
        # Select the group's rows once so that dates, sales and predictions all
        # come from the same group (previously every panel's prediction curve
        # was taken from group '4').
        group_rows = df.loc[df['Group'] == group]
        ax.plot(group_rows['Date'], group_rows['sales'], color='grey')
        ax.plot(group_rows['Date'], group_rows['pred'], color = 'red')
        ax.set_title(title)
        ax.xaxis.set_tick_params(rotation=45)
    plt.tight_layout()
    return fig, axs
@st.cache_data
def load_data():
    """
    Load the pickled dataset parameters and the test data from the `data`
    folder, keeping only rows of branch "15" that belong to the article groups
    "6", "7", "4" and "1".

    Returns the tuple (parameters, filtered DataFrame).
    """
    with open('data/parameters_q.pkl', 'rb') as handle:
        dataset_params = pickle.load(handle)

    frame = pd.read_pickle('data/test_data.pkl')
    in_branch = frame["Branch"] == "15"
    in_groups = frame["Group"].isin(["6","7","4","1"])
    return dataset_params, frame.loc[in_branch & in_groups]
@st.cache_resource
def init_model():
    """
    Restore the Temporal Fusion Transformer from the bundled checkpoint,
    mapping its weights onto the CPU.
    """
    cpu = torch.device('cpu')
    return TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=cpu)
def main():
    """
    Render the Streamlit demo page: title, banner image and abstract, the
    scenario widgets (forecast start, temperature change, rain indicator),
    the resulting forecast plot and the source credits.
    """
    # Page header
    st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
    banner = Image.open('data/image.png')
    st.image(banner, caption='Coding.Waterkant Festival for AI')
    st.markdown(body = """
### Abstract
Multi-horizon forecasting often contains a complex mix of inputs – including
static (i.e. time-invariant) covariates, known future inputs, and other exogenous
time series that are only observed in the past – without any prior information
on how they interact with the target. Several deep learning methods have been
proposed, but they are typically ‘black-box’ models which do not shed light on
how they use the full range of inputs present in practical scenarios. In this pa-
per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
based architecture which combines high-performance multi-horizon forecasting
with interpretable insights into temporal dynamics. To learn temporal rela-
tionships at different scales, TFT uses recurrent layers for local processing and
interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
cialized components to select relevant features and a series of gating layers to
suppress unnecessary components, enabling high performance in a wide range of
scenarios. On a variety of real-world datasets, we demonstrate significant per-
formance improvements over existing benchmarks, and showcase three practical
interpretability use cases of TFT.
### Experiments
We implemented TFT for sales multi-horizon sales forecast during Coding.Waterkant.
Please try our implementation and adjust some of the training data.
Adjustments to the model and extention with Quantile forecast are coming soon ;)
""")

    # Translation of the rain selectbox choice into the precipitation flag
    rain_flags = {
        "Yes" : 1,
        "No" : 0
    }
    parameters, data = load_data()
    model = init_model()

    # Scenario widgets; their values are read back through st.session_state
    # using the keys given here.
    default_start = datetime.date(2022, 10, 24)
    earliest_start = datetime.date(2022, 6, 26) + datetime.timedelta(days = 35)
    latest_start = datetime.date(2023, 6, 26) - datetime.timedelta(days = 30)
    st.date_input(
        "Start of Forecast",
        value = default_start,
        min_value = earliest_start,
        max_value = latest_start,
        key = "date",
    )
    st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
    st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")

    # Forecast on a copy of the data adjusted to the chosen scenario
    scenario_loader = prepare_dataset(
        parameters,
        data.copy(),
        st.session_state.rain,
        st.session_state.temperature,
        st.session_state.date,
        rain_flags,
    )
    forecast = predict(model, scenario_loader)

    # Plot observed sales against the forecast
    plot_frame = adjust_data_for_plot(data.copy(), forecast)
    figure, _ = generate_plot(plot_frame)
    st.pyplot(figure)

    st.markdown(body = """
### Sources
**Paper:** [Bryan Lim et al. in Temporal Fusion Transformers (TFT)](https://arxiv.org/abs/1912.09363). <br>
**Demo created by:** [MalteLeuschner - leuschnm](https://github.com/MalteLeuschner)
""", unsafe_allow_html = True)
# Launch the app only when this file is executed as a script.
if "__main__" == __name__:
    main()
|