leuschnm commited on
Commit
bb5ce46
·
1 Parent(s): c5196fd

change conversion

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -19,25 +19,27 @@ from pytorch_forecasting import (
19
  from PIL import Image
20
 
21
  ## Functions
22
- def raw_preds_to_df(raw,quantiles = None):
23
  """
24
  raw is output of model.predict with return_index=True
25
  quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
26
  in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
27
  pred_idx the index of the predicted date i.e. time_idx + h - 1
28
  """
29
- index = raw[2]
30
- preds = raw[0].prediction
31
- dec_len = preds.shape[1]
32
  n_quantiles = preds.shape[-1]
33
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
34
- preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
35
- preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
36
- preds_df = preds_df.assign(pred=preds.flatten().numpy())
37
  if quantiles is not None:
38
- preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})
39
 
40
- preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
 
 
41
  return preds_df
42
 
43
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
@@ -57,7 +59,8 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
57
 
58
  def predict(_model, _dataloader, datepicker):
59
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
60
- preds = raw_preds_to_df(out, quantiles = None)
 
61
  return preds[["pred_idx", "Group", "pred"]]
62
 
63
  def adjust_data_for_plot(df, preds):
@@ -96,7 +99,7 @@ def generate_plot(df):
96
 
97
  @st.cache_data
98
  def load_data():
99
- with open('data/parameters.pkl', 'rb') as f:
100
  parameters = pickle.load(f)
101
  df = pd.read_pickle('data/test_data.pkl')
102
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
@@ -104,7 +107,7 @@ def load_data():
104
 
105
  @st.cache_resource
106
  def init_model():
107
- model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
108
  return model
109
 
110
  def main():
 
19
  from PIL import Image
20
 
21
  ## Functions
22
+ def raw_preds_to_df(raw, idx_offset, quantiles = None):
23
  """
24
  raw is output of model.predict with return_index=True
25
  quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
26
  in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
27
  pred_idx the index of the predicted date i.e. time_idx + h - 1
28
  """
29
+ index = raw.index
30
+ preds = raw.output.prediction
31
+ dec_len = raw.output.prediction.shape[1]
32
  n_quantiles = preds.shape[-1]
33
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
34
+ preds_df = preds_df.assign(Horizon=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
35
+ preds_df = preds_df.assign(Quantile=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
36
+ preds_df = preds_df.assign(Prediction=preds.flatten().cpu().numpy())
37
  if quantiles is not None:
38
+ preds_df['Quantile'] = preds_df['Quantile'].map({i:q for i,q in enumerate(quantiles)})
39
 
40
+ preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['Horizon'] - 1
41
+ preds_df['Date'] = pd.to_datetime(idx_offset)
42
+ preds_df['Date'] = preds_df['Date'] + preds_df['pred_idx'].apply(pd.DateOffset)
43
  return preds_df
44
 
45
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
 
59
 
60
  def predict(_model, _dataloader, datepicker):
61
  out = _model.predict(_dataloader, mode="raw", return_x=True, return_index=True)
62
+ first_date = min(df["Date"])
63
+ preds = raw_preds_to_df(out, first_date)
64
  return preds[["pred_idx", "Group", "pred"]]
65
 
66
  def adjust_data_for_plot(df, preds):
 
99
 
100
  @st.cache_data
101
  def load_data():
102
+ with open('data/parameters_q.pkl', 'rb') as f:
103
  parameters = pickle.load(f)
104
  df = pd.read_pickle('data/test_data.pkl')
105
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
 
107
 
108
  @st.cache_resource
109
  def init_model():
110
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
111
  return model
112
 
113
  def main():