Spaces:
Runtime error
Runtime error
change conversion
Browse files
app.py
CHANGED
@@ -26,10 +26,11 @@ def raw_preds_to_df(raw, quantiles = None):
|
|
26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
28 |
"""
|
29 |
-
index = raw
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
34 |
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
35 |
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
@@ -39,7 +40,7 @@ def raw_preds_to_df(raw, quantiles = None):
|
|
39 |
|
40 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
41 |
return preds_df
|
42 |
-
|
43 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
44 |
if rain != "Default":
|
45 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
@@ -96,7 +97,7 @@ def generate_plot(df):
|
|
96 |
|
97 |
@st.cache_data
|
98 |
def load_data():
|
99 |
-
with open('data/
|
100 |
parameters = pickle.load(f)
|
101 |
df = pd.read_pickle('data/test_data.pkl')
|
102 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
@@ -104,7 +105,7 @@ def load_data():
|
|
104 |
|
105 |
@st.cache_resource
|
106 |
def init_model():
|
107 |
-
model = TemporalFusionTransformer.load_from_checkpoint('model/
|
108 |
return model
|
109 |
|
110 |
def main():
|
|
|
26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
28 |
"""
|
29 |
+
index = raw[2]
|
30 |
+
output = raw[0]
|
31 |
+
preds = output.prediction
|
32 |
+
dec_len = output.prediction.shape[1]
|
33 |
+
n_quantiles = output.prediction.shape[-1]
|
34 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
35 |
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
36 |
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
|
|
40 |
|
41 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
42 |
return preds_df
|
43 |
+
|
44 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
45 |
if rain != "Default":
|
46 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
|
|
97 |
|
98 |
@st.cache_data
|
99 |
def load_data():
|
100 |
+
with open('data/parameters_q.pkl', 'rb') as f:
|
101 |
parameters = pickle.load(f)
|
102 |
df = pd.read_pickle('data/test_data.pkl')
|
103 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
|
105 |
|
106 |
@st.cache_resource
|
107 |
def init_model():
|
108 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
|
109 |
return model
|
110 |
|
111 |
def main():
|