Spaces:
Runtime error
Runtime error
change conversion
Browse files
app.py
CHANGED
@@ -19,25 +19,27 @@ from pytorch_forecasting import (
|
|
19 |
from PIL import Image
|
20 |
|
21 |
## Functions
|
22 |
-
def raw_preds_to_df(raw,quantiles = None):
|
23 |
"""
|
24 |
raw is output of model.predict with return_index=True
|
25 |
quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
|
26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
28 |
"""
|
29 |
-
index = raw
|
30 |
-
preds = raw
|
31 |
-
dec_len =
|
32 |
n_quantiles = preds.shape[-1]
|
33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
34 |
-
preds_df = preds_df.assign(
|
35 |
-
preds_df = preds_df.assign(
|
36 |
-
preds_df = preds_df.assign(
|
37 |
if quantiles is not None:
|
38 |
-
preds_df['
|
39 |
|
40 |
-
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['
|
|
|
|
|
41 |
return preds_df
|
42 |
|
43 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
@@ -57,7 +59,8 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
57 |
|
58 |
def predict(_model, _dataloader, datepicker):
|
59 |
out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
|
60 |
-
|
|
|
61 |
return preds[["pred_idx", "Group", "pred"]]
|
62 |
|
63 |
def adjust_data_for_plot(df, preds):
|
@@ -96,7 +99,7 @@ def generate_plot(df):
|
|
96 |
|
97 |
@st.cache_data
|
98 |
def load_data():
|
99 |
-
with open('data/
|
100 |
parameters = pickle.load(f)
|
101 |
df = pd.read_pickle('data/test_data.pkl')
|
102 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
@@ -104,7 +107,7 @@ def load_data():
|
|
104 |
|
105 |
@st.cache_resource
|
106 |
def init_model():
|
107 |
-
model = TemporalFusionTransformer.load_from_checkpoint('model/
|
108 |
return model
|
109 |
|
110 |
def main():
|
|
|
19 |
from PIL import Image
|
20 |
|
21 |
## Functions
|
22 |
+
def raw_preds_to_df(raw, idx_offset, quantiles = None):
|
23 |
"""
|
24 |
raw is output of model.predict with return_index=True
|
25 |
quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
|
26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
28 |
"""
|
29 |
+
index = raw.index
|
30 |
+
preds = raw.output.prediction
|
31 |
+
dec_len = raw.output.prediction.shape[1]
|
32 |
n_quantiles = preds.shape[-1]
|
33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
34 |
+
preds_df = preds_df.assign(Horizon=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
35 |
+
preds_df = preds_df.assign(Quantile=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
36 |
+
preds_df = preds_df.assign(Prediction=preds.flatten().cpu().numpy())
|
37 |
if quantiles is not None:
|
38 |
+
preds_df['Quantile'] = preds_df['Quantile'].map({i:q for i,q in enumerate(quantiles)})
|
39 |
|
40 |
+
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['Horizon'] - 1
|
41 |
+
preds_df['Date'] = pd.to_datetime(idx_offset)
|
42 |
+
preds_df['Date'] = preds_df['Date'] + preds_df['pred_idx'].apply(pd.DateOffset)
|
43 |
return preds_df
|
44 |
|
45 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
|
59 |
|
60 |
def predict(_model, _dataloader, datepicker):
|
61 |
out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
|
62 |
+
first_date = min(df["Date"])
|
63 |
+
preds = raw_preds_to_df(out, first_date)
|
64 |
return preds[["pred_idx", "Group", "pred"]]
|
65 |
|
66 |
def adjust_data_for_plot(df, preds):
|
|
|
99 |
|
100 |
@st.cache_data
|
101 |
def load_data():
|
102 |
+
with open('data/parameters_q.pkl', 'rb') as f:
|
103 |
parameters = pickle.load(f)
|
104 |
df = pd.read_pickle('data/test_data.pkl')
|
105 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
|
107 |
|
108 |
@st.cache_resource
|
109 |
def init_model():
|
110 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
|
111 |
return model
|
112 |
|
113 |
def main():
|