Spaces:
Runtime error
bug fix
app.py
CHANGED
@@ -56,75 +56,7 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
 def predict(model, dataloader):
     return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
 
-
-with open('data/parameters.pkl', 'rb') as f:
-    parameters = pickle.load(f)
-model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
-
-df = pd.read_pickle('data/test_data.pkl')
-df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
-
-rain_mapping = {
-    "Yes" : 1,
-    "No" : 0
-}
-
-# Start App
-st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
-
-st.markdown(body = """
-### Abstract
-Multi-horizon forecasting often contains a complex mix of inputs – including
-static (i.e. time-invariant) covariates, known future inputs, and other exogenous
-time series that are only observed in the past – without any prior information
-on how they interact with the target. Several deep learning methods have been
-proposed, but they are typically ‘black-box’ models which do not shed light on
-how they use the full range of inputs present in practical scenarios. In this pa-
-per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
-based architecture which combines high-performance multi-horizon forecasting
-with interpretable insights into temporal dynamics. To learn temporal rela-
-tionships at different scales, TFT uses recurrent layers for local processing and
-interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
-cialized components to select relevant features and a series of gating layers to
-suppress unnecessary components, enabling high performance in a wide range of
-scenarios. On a variety of real-world datasets, we demonstrate significant per-
-formance improvements over existing benchmarks, and showcase three practical
-interpretability use cases of TFT.
-
-### Experiments
-""")
-
-rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
-
-temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
-
-datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
-
-fig, axs = plt.subplots(2, 2, figsize=(8, 6))
-
-# Plot scatter plots for each group
-axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
-axs[0, 0].set_title('Article Group 1')
-
-axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
-axs[0, 1].set_title('Article Group 2')
-
-axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
-axs[1, 0].set_title('Article Group 3')
-
-axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
-axs[1, 1].set_title('Article Group 4')
-
-# Adjust spacing between subplots
-plt.tight_layout()
-
-for ax in axs.flat:
-    ax.set_xlim(df['Date'].min(), df['Date'].max())
-    ax.set_ylim(df['sales'].min(), df['sales'].max())
-
-st.pyplot(fig)
-
-st.button("Forecast Sales", type="primary") #on_click=None,
+#on_click=None,
 
 # %%
 #preds = raw_preds_to_df(out, quantiles = None)
@@ -133,4 +65,77 @@ st.button("Forecast Sales", type="primary") #on_click=None,
 #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
 #preds.drop(columns=['time_idx_y'],inplace=True)
 
+def main():
+    ## Initiate Data
+    with open('data/parameters.pkl', 'rb') as f:
+        parameters = pickle.load(f)
+    model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
+
+    df = pd.read_pickle('data/test_data.pkl')
+    df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
+
+    rain_mapping = {
+        "Yes" : 1,
+        "No" : 0
+    }
+
+    # Start App
+    st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
+
+    st.markdown(body = """
+### Abstract
+Multi-horizon forecasting often contains a complex mix of inputs – including
+static (i.e. time-invariant) covariates, known future inputs, and other exogenous
+time series that are only observed in the past – without any prior information
+on how they interact with the target. Several deep learning methods have been
+proposed, but they are typically ‘black-box’ models which do not shed light on
+how they use the full range of inputs present in practical scenarios. In this pa-
+per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
+based architecture which combines high-performance multi-horizon forecasting
+with interpretable insights into temporal dynamics. To learn temporal rela-
+tionships at different scales, TFT uses recurrent layers for local processing and
+interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
+cialized components to select relevant features and a series of gating layers to
+suppress unnecessary components, enabling high performance in a wide range of
+scenarios. On a variety of real-world datasets, we demonstrate significant per-
+formance improvements over existing benchmarks, and showcase three practical
+interpretability use cases of TFT.
+
+### Experiments
+""")
+    st.write(df.head(5))
+    rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
+
+    temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
+
+    datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
+
+    fig, axs = plt.subplots(2, 2, figsize=(8, 6))
+
+    # Plot scatter plots for each group
+    axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
+    axs[0, 0].set_title('Article Group 1')
+
+    axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
+    axs[0, 1].set_title('Article Group 2')
+
+    axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
+    axs[1, 0].set_title('Article Group 3')
+
+    axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
+    axs[1, 1].set_title('Article Group 4')
+
+    # Adjust spacing between subplots
+    plt.tight_layout()
+
+    for ax in axs.flat:
+        ax.set_xlim(df['Date'].min(), df['Date'].max())
+        ax.set_ylim(df['sales'].min(), df['sales'].max())
+
+    st.pyplot(fig)
+
+    st.button("Forecast Sales", type="primary")
+
+if __name__ == '__main__':
+    main()
 
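For context, the predict(model, dataloader) helper kept at the top of app.py needs a dataloader built from the test data (prepare_dataset, referenced in the first hunk header, is not shown in this commit). A minimal sketch of the usual pytorch_forecasting wiring, assuming data/parameters.pkl stores the TimeSeriesDataSet parameters saved via get_parameters() — an assumption, not confirmed by the diff:

import pickle
import pandas as pd
import torch
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# Load the saved dataset parameters, checkpoint, and test data as app.py does
with open('data/parameters.pkl', 'rb') as f:
    parameters = pickle.load(f)
model = TemporalFusionTransformer.load_from_checkpoint(
    'model/tft_check.ckpt', map_location=torch.device('cpu'))
df = pd.read_pickle('data/test_data.pkl')

# Rebuild a TimeSeriesDataSet from the saved parameters and wrap it in a
# dataloader; the exact keyword arguments depend on how the dataset was created.
test_dataset = TimeSeriesDataSet.from_parameters(parameters, df, predict=True)
test_dataloader = test_dataset.to_dataloader(train=False, batch_size=64)

# Same call that predict() in app.py performs
raw_out = model.predict(test_dataloader, mode="raw", return_x=True, return_index=True)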