## Imports
import pickle
import warnings
import streamlit as st
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime

import torch
from torch.distributions import Normal
from pytorch_forecasting import (
    TimeSeriesDataSet,
    TemporalFusionTransformer,
)

## Functions
def raw_preds_to_df(raw, quantiles=None):
    """
    Convert raw model output into a long-format DataFrame.

    `raw` is the output of model.predict(..., mode="raw", return_x=True,
    return_index=True). `quantiles` can be provided, e.g. [0.1, 0.5, 0.9], to
    label the quantile column with interpretable values instead of indices.
    In the output, `time_idx` is the first prediction time index (one step
    after the knowledge cutoff) and `pred_idx` is the index of the predicted
    date, i.e. time_idx + h - 1.
    """
    index = raw[2]
    preds = raw[0].prediction
    dec_len = preds.shape[1]
    n_quantiles = preds.shape[-1]
    # One row per (series, horizon step h, quantile q).
    preds_df = pd.DataFrame(
        index.values.repeat(dec_len * n_quantiles, axis=0), columns=index.columns
    )
    preds_df = preds_df.assign(
        h=np.tile(
            np.repeat(np.arange(1, 1 + dec_len), n_quantiles),
            len(preds_df) // (dec_len * n_quantiles),
        )
    )
    preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles), len(preds_df) // n_quantiles))
    preds_df = preds_df.assign(pred=preds.flatten().numpy())
    if quantiles is not None:
        preds_df['q'] = preds_df['q'].map({i: q for i, q in enumerate(quantiles)})

    preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
    return preds_df
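
# Usage sketch (hypothetical `loader`; assumes the model was trained with
# pytorch_forecasting's default QuantileLoss, i.e. 7 quantiles):
#   out = model.predict(loader, mode="raw", return_x=True, return_index=True)
#   preds_df = raw_preds_to_df(out, quantiles=[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98])
#   medians = preds_df.loc[preds_df["q"] == 0.5]  # one row per series and horizon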

def prepare_dataset(parameters, df, rain, temperature, datepicker):
    # Apply the user scenario: override the rain indicator and shift both
    # temperature series by the selected offset.
    if rain != "Default":
        df["MTXWTH_Day_precip"] = rain_mapping[rain]

    df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
    df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature

    # Keep 35 days of history before the forecast start (the model's lookback)
    # and 30 days after it (the forecast horizon); cast the picked date to a
    # Timestamp for a robust comparison with the datetime64 "Date" column.
    lowerbound = pd.Timestamp(datepicker) - datetime.timedelta(days=35)
    upperbound = pd.Timestamp(datepicker) + datetime.timedelta(days=30)
    df = df.loc[(df["Date"] > lowerbound) & (df["Date"] <= upperbound)]

    df = TimeSeriesDataSet.from_parameters(parameters, df)
    return df.to_dataloader(train=False, batch_size=256, num_workers=0)
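
# Usage sketch: pass a copy so the temperature shift and rain override do not
# accumulate on the cached frame across Streamlit reruns, e.g.
#   loader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker)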

def predict(model, dataloader): 
    return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
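
# Note: with mode="raw", return_x=True and return_index=True, predict() returns a
# tuple-like Prediction object: element 0 is the raw network output (its
# .prediction tensor has shape [n_series, decoder_length, n_quantiles]), element 1
# is the input batch, and element 2 is the index frame consumed by raw_preds_to_df().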
    
## Load Data and Model
with open('data/parameters.pkl', 'rb') as f:
    parameters = pickle.load(f)
model = TemporalFusionTransformer.load_from_checkpoint(
    'model/tft_check.ckpt', map_location=torch.device('cpu')
)

# Restrict the demo to one branch and four product groups.
df = pd.read_pickle('data/test_data.pkl')
df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6", "7", "4", "1"]))]

rain_mapping = {
    "Yes": 1,
    "No": 0,
}

## Start App
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")

st.markdown(body="""
    ### Abstract
    Multi-horizon forecasting often contains a complex mix of inputs – including
    static (i.e. time-invariant) covariates, known future inputs, and other exogenous
    time series that are only observed in the past – without any prior information
    on how they interact with the target. Several deep learning methods have been
    proposed, but they are typically ‘black-box’ models which do not shed light on
    how they use the full range of inputs present in practical scenarios. In this
    paper, we introduce the Temporal Fusion Transformer (TFT) – a novel
    attention-based architecture which combines high-performance multi-horizon
    forecasting with interpretable insights into temporal dynamics. To learn
    temporal relationships at different scales, TFT uses recurrent layers for
    local processing and interpretable self-attention layers for long-term
    dependencies. TFT utilizes specialized components to select relevant features
    and a series of gating layers to suppress unnecessary components, enabling
    high performance in a wide range of scenarios. On a variety of real-world
    datasets, we demonstrate significant performance improvements over existing
    benchmarks, and showcase three practical interpretability use cases of TFT.
""")

rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))

# Streamlit requires min/max/value to match the float step, so use floats here.
temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)

# Constrain the picker so prepare_dataset() always finds 35 days of history and
# 30 days of future dates inside the test range.
datepicker = st.date_input(
    "Start of Forecast",
    datetime.date(2022, 12, 24),
    min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days=35),
    max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days=30),
)

# Placeholder chart: a random histogram shown before any forecast is requested.
arr = np.random.normal(1, 1, size=100)
fig, ax = plt.subplots()
ax.hist(arr, bins=20)

st.pyplot(fig)

if st.button("Forecast Sales", type="primary"):
    # Build the scenario dataloader and run the model; pass a copy so the
    # temperature shift does not accumulate across reruns.
    dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker)
    out = predict(model, dataloader)
    preds = raw_preds_to_df(out, quantiles=None)

    # Attach actuals and covariates for the predicted dates.
    preds = preds.merge(
        df[['time_idx', 'Group', 'Branch', 'sales', 'weight', 'Date',
            'MTXWTH_Day_precip', 'MTXWTH_Temp_max', 'MTXWTH_Temp_min']],
        how='left',
        left_on=['pred_idx', 'Group', 'Branch'],
        right_on=['time_idx', 'Group', 'Branch'],
    )
    preds = preds.rename(columns={'time_idx_x': 'time_idx'}).drop(columns=['time_idx_y'])
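
    # Minimal sketch of a forecast plot (assumes integer quantile codes, i.e.
    # quantiles=None above; with the default 7 quantiles, code 3 is the median):
    median = preds.loc[preds['q'] == 3]
    fig_fc, ax_fc = plt.subplots()
    for group, g in median.groupby('Group'):
        ax_fc.plot(g['Date'], g['pred'], label=f"Group {group}")
    ax_fc.set_ylabel('Predicted sales')
    ax_fc.legend()
    st.pyplot(fig_fc)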