leuschnm commited on
Commit
3a5b465
·
1 Parent(s): dd58832
Files changed (1) hide show
  1. app.py +74 -69
app.py CHANGED
@@ -56,75 +56,7 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
56
  def predict(model, dataloader):
57
  return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
- ## Initiate Data
60
- with open('data/parameters.pkl', 'rb') as f:
61
- parameters = pickle.load(f)
62
- model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
63
-
64
- df = pd.read_pickle('data/test_data.pkl')
65
- df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
66
-
67
- rain_mapping = {
68
- "Yes" : 1,
69
- "No" : 0
70
- }
71
-
72
- # Start App
73
- st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
74
-
75
- st.markdown(body = """
76
- ### Abstract
77
- Multi-horizon forecasting often contains a complex mix of inputs – including
78
- static (i.e. time-invariant) covariates, known future inputs, and other exogenous
79
- time series that are only observed in the past – without any prior information
80
- on how they interact with the target. Several deep learning methods have been
81
- proposed, but they are typically ‘black-box’ models which do not shed light on
82
- how they use the full range of inputs present in practical scenarios. In this pa-
83
- per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
84
- based architecture which combines high-performance multi-horizon forecasting
85
- with interpretable insights into temporal dynamics. To learn temporal rela-
86
- tionships at different scales, TFT uses recurrent layers for local processing and
87
- interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
88
- cialized components to select relevant features and a series of gating layers to
89
- suppress unnecessary components, enabling high performance in a wide range of
90
- scenarios. On a variety of real-world datasets, we demonstrate significant per-
91
- formance improvements over existing benchmarks, and showcase three practical
92
- interpretability use cases of TFT.
93
-
94
- ### Experiments
95
- """)
96
-
97
- rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
98
-
99
- temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
100
-
101
- datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
102
-
103
- fig, axs = plt.subplots(2, 2, figsize=(8, 6))
104
-
105
- # Plot scatter plots for each group
106
- axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
107
- axs[0, 0].set_title('Article Group 1')
108
-
109
- axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
110
- axs[0, 1].set_title('Article Group 2')
111
-
112
- axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
113
- axs[1, 0].set_title('Article Group 3')
114
-
115
- axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
116
- axs[1, 1].set_title('Article Group 4')
117
-
118
- # Adjust spacing between subplots
119
- plt.tight_layout()
120
-
121
- for ax in axs.flat:
122
- ax.set_xlim(df['Date'].min(), df['Date'].max())
123
- ax.set_ylim(df['sales'].min(), df['sales'].max())
124
-
125
- st.pyplot(fig)
126
-
127
- st.button("Forecast Sales", type="primary") #on_click=None,
128
 
129
  # %%
130
  #preds = raw_preds_to_df(out, quantiles = None)
@@ -133,4 +65,77 @@ st.button("Forecast Sales", type="primary") #on_click=None,
133
  #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
134
  #preds.drop(columns=['time_idx_y'],inplace=True)
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
 
56
  def predict(model, dataloader):
57
  return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
+ #on_click=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # %%
62
  #preds = raw_preds_to_df(out, quantiles = None)
 
65
  #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
66
  #preds.drop(columns=['time_idx_y'],inplace=True)
67
 
68
+ def main():
69
+ ## Initiate Data
70
+ with open('data/parameters.pkl', 'rb') as f:
71
+ parameters = pickle.load(f)
72
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
73
+
74
+ df = pd.read_pickle('data/test_data.pkl')
75
+ df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
76
+
77
+ rain_mapping = {
78
+ "Yes" : 1,
79
+ "No" : 0
80
+ }
81
+
82
+ # Start App
83
+ st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
84
+
85
+ st.markdown(body = """
86
+ ### Abstract
87
+ Multi-horizon forecasting often contains a complex mix of inputs – including
88
+ static (i.e. time-invariant) covariates, known future inputs, and other exogenous
89
+ time series that are only observed in the past – without any prior information
90
+ on how they interact with the target. Several deep learning methods have been
91
+ proposed, but they are typically ‘black-box’ models which do not shed light on
92
+ how they use the full range of inputs present in practical scenarios. In this pa-
93
+ per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
94
+ based architecture which combines high-performance multi-horizon forecasting
95
+ with interpretable insights into temporal dynamics. To learn temporal rela-
96
+ tionships at different scales, TFT uses recurrent layers for local processing and
97
+ interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
98
+ cialized components to select relevant features and a series of gating layers to
99
+ suppress unnecessary components, enabling high performance in a wide range of
100
+ scenarios. On a variety of real-world datasets, we demonstrate significant per-
101
+ formance improvements over existing benchmarks, and showcase three practical
102
+ interpretability use cases of TFT.
103
+
104
+ ### Experiments
105
+ """)
106
+ st.write(df.head(5))
107
+ rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
108
+
109
+ temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25)
110
+
111
+ datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
112
+
113
+ fig, axs = plt.subplots(2, 2, figsize=(8, 6))
114
+
115
+ # Plot scatter plots for each group
116
+ axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
117
+ axs[0, 0].set_title('Article Group 1')
118
+
119
+ axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
120
+ axs[0, 1].set_title('Article Group 2')
121
+
122
+ axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
123
+ axs[1, 0].set_title('Article Group 3')
124
+
125
+ axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
126
+ axs[1, 1].set_title('Article Group 4')
127
+
128
+ # Adjust spacing between subplots
129
+ plt.tight_layout()
130
+
131
+ for ax in axs.flat:
132
+ ax.set_xlim(df['Date'].min(), df['Date'].max())
133
+ ax.set_ylim(df['sales'].min(), df['sales'].max())
134
+
135
+ st.pyplot(fig)
136
+
137
+ st.button("Forecast Sales", type="primary")
138
+
139
+ if __name__ == '__main__':
140
+ main()
141