leuschnm commited on
Commit
be7c874
·
1 Parent(s): 2c75212

change conversion

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -26,10 +26,11 @@ def raw_preds_to_df(raw, quantiles = None):
26
  in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
27
  pred_idx the index of the predicted date i.e. time_idx + h - 1
28
  """
29
- index = raw.index
30
- preds = raw.output.prediction
31
- dec_len = raw.output.prediction.shape[1]
32
- n_quantiles = preds.shape[-1]
 
33
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
34
  preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
35
  preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
@@ -39,7 +40,7 @@ def raw_preds_to_df(raw, quantiles = None):
39
 
40
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
41
  return preds_df
42
-
43
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
44
  if rain != "Default":
45
  df["MTXWTH_Day_precip"] = mapping[rain]
@@ -96,7 +97,7 @@ def generate_plot(df):
96
 
97
  @st.cache_data
98
  def load_data():
99
- with open('data/parameters.pkl', 'rb') as f:
100
  parameters = pickle.load(f)
101
  df = pd.read_pickle('data/test_data.pkl')
102
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
@@ -104,7 +105,7 @@ def load_data():
104
 
105
  @st.cache_resource
106
  def init_model():
107
- model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
108
  return model
109
 
110
  def main():
 
26
  in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
27
  pred_idx the index of the predicted date i.e. time_idx + h - 1
28
  """
29
+ index = raw[2]
30
+ output = raw[0]
31
+ preds = output.prediction
32
+ dec_len = output.prediction.shape[1]
33
+ n_quantiles = output.prediction.shape[-1]
34
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
35
  preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
36
  preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
 
40
 
41
  preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
42
  return preds_df
43
+
44
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
45
  if rain != "Default":
46
  df["MTXWTH_Day_precip"] = mapping[rain]
 
97
 
98
  @st.cache_data
99
  def load_data():
100
+ with open('data/parameters_q.pkl', 'rb') as f:
101
  parameters = pickle.load(f)
102
  df = pd.read_pickle('data/test_data.pkl')
103
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
 
105
 
106
  @st.cache_resource
107
  def init_model():
108
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
109
  return model
110
 
111
  def main():