leuschnm commited on
Commit
add096e
·
1 Parent(s): 31f456a
Files changed (2) hide show
  1. app.py +17 -8
  2. data/image.png +0 -0
app.py CHANGED
@@ -16,6 +16,8 @@ from pytorch_forecasting import (
16
  TemporalFusionTransformer,
17
  )
18
 
 
 
19
  ## Functions
20
  def raw_preds_to_df(raw,quantiles = None):
21
  """
@@ -38,7 +40,6 @@ def raw_preds_to_df(raw,quantiles = None):
38
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
  return preds_df
40
 
41
- #@st.cache_data
42
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
43
  if rain != "Default":
44
  df["MTXWTH_Day_precip"] = mapping[rain]
@@ -54,13 +55,11 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
54
  df = TimeSeriesDataSet.from_parameters(_parameters, df)
55
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
56
 
57
- #@st.cache_data
58
  def predict(_model, _dataloader, datepicker):
59
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
60
  preds = raw_preds_to_df(out, quantiles = None)
61
  return preds[["pred_idx", "Group", "pred"]]
62
 
63
- #@st.cache_data
64
  def adjust_data_for_plot(df, preds):
65
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
66
  df = df[~df["pred"].isna()]
@@ -111,7 +110,11 @@ def init_model():
111
  def main():
112
  # Start App
113
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
114
-
 
 
 
 
115
  st.markdown(body = """
116
  ### Abstract
117
  Multi-horizon forecasting often contains a complex mix of inputs – including
@@ -132,6 +135,10 @@ def main():
132
  interpretability use cases of TFT.
133
 
134
  ### Experiments
 
 
 
 
135
  """)
136
 
137
  try:
@@ -158,12 +165,14 @@ def main():
158
 
159
  data_plot = adjust_data_for_plot(df.copy(), preds)
160
  fig, _ = generate_plot(data_plot)
161
-
162
- st.pyplot(fig)
163
- rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
164
- temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
165
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
166
 
 
 
 
 
 
167
  if __name__ == '__main__':
168
  main()
169
 
 
16
  TemporalFusionTransformer,
17
  )
18
 
19
+ from PIL import Image
20
+
21
  ## Functions
22
  def raw_preds_to_df(raw,quantiles = None):
23
  """
 
40
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
41
  return preds_df
42
 
 
43
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
44
  if rain != "Default":
45
  df["MTXWTH_Day_precip"] = mapping[rain]
 
55
  df = TimeSeriesDataSet.from_parameters(_parameters, df)
56
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
57
 
 
58
  def predict(_model, _dataloader, datepicker):
59
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
60
  preds = raw_preds_to_df(out, quantiles = None)
61
  return preds[["pred_idx", "Group", "pred"]]
62
 
 
63
  def adjust_data_for_plot(df, preds):
64
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
65
  df = df[~df["pred"].isna()]
 
110
  def main():
111
  # Start App
112
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
113
+
114
+
115
+ image = Image.open('data/image.png')
116
+
117
+ st.image(image, caption='Coding.Waterkant Festival for AI')
118
  st.markdown(body = """
119
  ### Abstract
120
  Multi-horizon forecasting often contains a complex mix of inputs – including
 
135
  interpretability use cases of TFT.
136
 
137
  ### Experiments
138
+ We implemented TFT for sales multi-horizon sales forecast during Coding.Waterkant.
139
+ Please try our implementation and adjust some of the training data.
140
+
141
+ Adjustments to the model and extention with Quantile forecasts are coming soon ;)
142
  """)
143
 
144
  try:
 
165
 
166
  data_plot = adjust_data_for_plot(df.copy(), preds)
167
  fig, _ = generate_plot(data_plot)
168
+
 
 
 
169
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30), key = "date")
170
 
171
+ st.pyplot(fig)
172
+
173
+ temperature = st.slider('Change in Temperature', min_value=-10.0, max_value=10.0, value=0.0, step=0.25, key = "temperature")
174
+ rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
175
+
176
  if __name__ == '__main__':
177
  main()
178
 
data/image.png ADDED