Spaces:
Runtime error
Runtime error
bug fix
Browse files
app.py
CHANGED
@@ -85,21 +85,29 @@ def generate_plot(df, dates, preds):
|
|
85 |
# Adjust spacing between subplots
|
86 |
plt.tight_layout()
|
87 |
|
88 |
-
for ax in axs.flat:
|
89 |
-
|
90 |
-
|
91 |
|
92 |
st.pyplot(fig)
|
93 |
-
|
94 |
|
95 |
-
|
96 |
-
|
97 |
with open('data/parameters.pkl', 'rb') as f:
|
98 |
parameters = pickle.load(f)
|
99 |
-
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
100 |
-
|
101 |
df = pd.read_pickle('data/test_data.pkl')
|
102 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
# Start App
|
105 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
|
|
85 |
# Adjust spacing between subplots
|
86 |
plt.tight_layout()
|
87 |
|
88 |
+
#for ax in axs.flat:
|
89 |
+
# ax.set_xlim(df['Date'].min(), df['Date'].max())
|
90 |
+
# ax.set_ylim(df['sales'].min(), df['sales'].max())
|
91 |
|
92 |
st.pyplot(fig)
|
|
|
93 |
|
94 |
+
@st.cache_data
|
95 |
+
def load_data():
|
96 |
with open('data/parameters.pkl', 'rb') as f:
|
97 |
parameters = pickle.load(f)
|
|
|
|
|
98 |
df = pd.read_pickle('data/test_data.pkl')
|
99 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
100 |
+
return parameters, df
|
101 |
+
|
102 |
+
@st.cache_resource
|
103 |
+
def init_model():
|
104 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
105 |
+
return model
|
106 |
+
|
107 |
+
def main():
|
108 |
+
## Initiate Data
|
109 |
+
parameters, df = load_data()
|
110 |
+
model = init_model()
|
111 |
|
112 |
# Start App
|
113 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|