## Imports
import pickle
import warnings
import streamlit as st
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime

import torch
from torch.distributions import Normal
from pytorch_forecasting import (
    TimeSeriesDataSet,
    TemporalFusionTransformer,
)

from PIL import Image

## Functions 
def raw_preds_to_df(raw, quantiles = None):
    """
    raw is output of model.predict with return_index=True
    quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
    in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
    pred_idx the index of the predicted date i.e. time_idx + h - 1
    """
    index = raw[2]
    output = raw[0]
    preds = output.prediction
    dec_len = output.prediction.shape[1]
    n_quantiles = output.prediction.shape[-1]
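    # Expand the prediction index to one row per (sample, horizon step, quantile):
    # each index row is repeated dec_len * n_quantiles times, then h enumerates the
    # horizon steps and q the quantile positions within each sample.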
    preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0), columns=index.columns)
    preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1, 1 + dec_len), n_quantiles), len(preds_df) // (dec_len * n_quantiles)))
    preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles), len(preds_df) // n_quantiles))
    preds_df = preds_df.assign(pred=preds.flatten().cpu().numpy())
    if quantiles is not None:
        preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})

    preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
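    # Result: one row per (prediction sample, horizon step, quantile),
    # i.e. B * dec_len * n_quantiles rows for B samples in the dataloader.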
    return preds_df
    
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
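    # Apply the user's what-if adjustments (rain indicator and temperature offset)
    # to the weather covariates before building the inference dataset.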
    if rain != "Default":
        df["MTXWTH_Day_precip"] = mapping[rain]
    
    df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
    df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature

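    # Keep only the window from 35 days before to 30 days after the chosen forecast
    # start; this covers the encoder history and the forecast horizon.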
    lowerbound = datepicker - datetime.timedelta(days = 35) 
    upperbound = datepicker + datetime.timedelta(days = 30) 

    df = df.loc[(df["Date"].dt.date>lowerbound) & (df["Date"].dt.date<=upperbound)]
    
    df = TimeSeriesDataSet.from_parameters(_parameters, df)
    return df.to_dataloader(train=False, batch_size=256,num_workers = 0)

def predict(model, dataloader):
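    # mode="raw" returns the full network output; return_index=True attaches the
    # (time_idx, group) index that raw_preds_to_df uses to label each prediction.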
    out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)#, trainer_kwargs=dict(accelerator="cpu"))
    preds = raw_preds_to_df(out, quantiles = None)
    return preds[["pred_idx", "Group", "pred"]]

def adjust_data_for_plot(df, preds):
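    # Attach predictions to the observed data on (time_idx, Group), keep only the
    # forecast horizon rows, and mask zero sales values so they are not drawn.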
    df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
    df = df[~df["pred"].isna()]
    df["sales"] = df["sales"].replace(0.0, np.nan)
    return df


def generate_plot(df):
    fig, axs = plt.subplots(2, 2, figsize=(8, 6))
    
    # Plot actual sales (grey) and model predictions (red) for each article group
    axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='grey')
    axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
    axs[0, 0].set_title('Article Group 1')
    axs[0, 0].xaxis.set_tick_params(rotation=45)

    axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='grey')
    axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'pred'], color = 'red')
    axs[0, 1].set_title('Article Group 2')
    axs[0, 1].xaxis.set_tick_params(rotation=45)
    
    axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='grey')
    axs[1, 0].plot(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'pred'], color = 'red')
    axs[1, 0].set_title('Article Group 3')
    axs[1, 0].xaxis.set_tick_params(rotation=45)
    
    axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='grey')
    axs[1, 1].plot(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'pred'], color = 'red')
    axs[1, 1].set_title('Article Group 4')
    axs[1, 1].xaxis.set_tick_params(rotation=45)

    plt.tight_layout()
    return fig, axs

@st.cache_data
def load_data():
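    # Load the stored TimeSeriesDataSet parameters and the test frame, restricted to
    # branch "15" and the four article groups shown in the dashboard.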
    with open('data/parameters_q.pkl', 'rb') as f:
        parameters = pickle.load(f)
    df = pd.read_pickle('data/test_data.pkl')
    df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
    return parameters, df

@st.cache_resource
def init_model():
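    # Load the trained TFT checkpoint on CPU; cached as a resource so Streamlit
    # only loads the model once.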
    model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
    return model
    
def main():
    # Start App
    st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")

    image = Image.open('data/image.png')
    st.image(image, caption='Coding.Waterkant Festival for AI')

    
    st.markdown(body = """
        ### Abstract
        Multi-horizon forecasting often contains a complex mix of inputs – including
        static (i.e. time-invariant) covariates, known future inputs, and other exogenous
        time series that are only observed in the past – without any prior information
        on how they interact with the target. Several deep learning methods have been
        proposed, but they are typically ‘black-box’ models which do not shed light on
        how they use the full range of inputs present in practical scenarios. In this
        paper, we introduce the Temporal Fusion Transformer (TFT) – a novel
        attention-based architecture which combines high-performance multi-horizon
        forecasting with interpretable insights into temporal dynamics. To learn temporal
        relationships at different scales, TFT uses recurrent layers for local processing
        and interpretable self-attention layers for long-term dependencies. TFT utilizes
        specialized components to select relevant features and a series of gating layers
        to suppress unnecessary components, enabling high performance in a wide range of
        scenarios. On a variety of real-world datasets, we demonstrate significant
        performance improvements over existing benchmarks, and showcase three practical
        interpretability use cases of TFT.
    
        ### Experiments
        We implemented TFT for multi-horizon sales forecasting during Coding.Waterkant.
        Please try our implementation and adjust some of the input data.

        Adjustments to the model and an extension with quantile forecasts are coming soon ;)
    """)

    RAIN_MAPPING = {
        "Yes" : 1,
        "No" : 0
    }
    
    parameters, df = load_data()
    model = init_model()
    
    datepicker = st.date_input(
        "Start of Forecast",
        value=datetime.date(2022, 10, 24),
        min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days=35),
        max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days=30),
        key="date",
    )
    temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
    rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")

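    # The widget values above are stored in st.session_state under their keys and
    # drive the what-if forecast below.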
    dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
    preds = predict(model, dataloader)

    data_plot = adjust_data_for_plot(df.copy(), preds)
    fig, _ = generate_plot(data_plot)

    st.pyplot(fig)
    
    st.markdown(body = """
        ### Sources
        **Paper:** [Bryan Lim et al., Temporal Fusion Transformers (TFT)](https://arxiv.org/abs/1912.09363). <br>
        **Demo created by:** [MalteLeuschner - leuschnm](https://github.com/MalteLeuschner)
    """, unsafe_allow_html = True)
    
if __name__ == '__main__':
    main()