leuschnm commited on
Commit
f2ea1f4
·
1 Parent(s): d59394d
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -85,21 +85,29 @@ def generate_plot(df, dates, preds):
85
  # Adjust spacing between subplots
86
  plt.tight_layout()
87
 
88
- for ax in axs.flat:
89
- ax.set_xlim(df['Date'].min(), df['Date'].max())
90
- ax.set_ylim(df['sales'].min(), df['sales'].max())
91
 
92
  st.pyplot(fig)
93
-
94
 
95
- def main():
96
- ## Initiate Data
97
  with open('data/parameters.pkl', 'rb') as f:
98
  parameters = pickle.load(f)
99
- model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
100
-
101
  df = pd.read_pickle('data/test_data.pkl')
102
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # Start App
105
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
 
85
  # Adjust spacing between subplots
86
  plt.tight_layout()
87
 
88
+ #for ax in axs.flat:
89
+ # ax.set_xlim(df['Date'].min(), df['Date'].max())
90
+ # ax.set_ylim(df['sales'].min(), df['sales'].max())
91
 
92
  st.pyplot(fig)
 
93
 
94
+ @st.cache_data
95
+ def load_data():
96
  with open('data/parameters.pkl', 'rb') as f:
97
  parameters = pickle.load(f)
 
 
98
  df = pd.read_pickle('data/test_data.pkl')
99
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
100
+ return parameters, df
101
+
102
+ @st.cache_resource
103
+ def init_model():
104
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
105
+ return model
106
+
107
+ def main():
108
+ ## Initiate Data
109
+ parameters, df = load_data()
110
+ model = init_model()
111
 
112
  # Start App
113
  st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")