leuschnm commited on
Commit
769a99c
·
1 Parent(s): d526bda

change conversion

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -19,7 +19,7 @@ from pytorch_forecasting import (
19
  from PIL import Image
20
 
21
  ## Functions
22
- def raw_preds_to_df(raw, idx_offset, quantiles = None):
23
  """
24
  raw is output of model.predict with return_index=True
25
  quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
@@ -31,15 +31,13 @@ def raw_preds_to_df(raw, idx_offset, quantiles = None):
31
  dec_len = raw.output.prediction.shape[1]
32
  n_quantiles = preds.shape[-1]
33
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
34
- preds_df = preds_df.assign(Horizon=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
35
- preds_df = preds_df.assign(Quantile=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
36
- preds_df = preds_df.assign(Prediction=preds.flatten().cpu().numpy())
37
  if quantiles is not None:
38
- preds_df['Quantile'] = preds_df['Quantile'].map({i:q for i,q in enumerate(quantiles)})
39
 
40
- preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['Horizon'] - 1
41
- preds_df['Date'] = pd.to_datetime(idx_offset)
42
- preds_df['Date'] = preds_df['Date'] + preds_df['pred_idx'].apply(pd.DateOffset)
43
  return preds_df
44
 
45
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
@@ -57,10 +55,9 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
57
  df = TimeSeriesDataSet.from_parameters(_parameters, df)
58
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
59
 
60
- def predict(_model, _dataloader, datepicker, df):
61
- out = _model.predict(_dataloader, mode="raw", return_index=True)# return_x=True,
62
- first_date = min(df["Date"])
63
- preds = raw_preds_to_df(raw = out, idx_offset = first_date)
64
  return preds[["pred_idx", "Group", "pred"]]
65
 
66
  def adjust_data_for_plot(df, preds):
@@ -157,7 +154,7 @@ def main():
157
  rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
158
 
159
  dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
160
- preds = predict(model, dataloader, st.session_state.date, df)
161
 
162
  data_plot = adjust_data_for_plot(df.copy(), preds)
163
  fig, _ = generate_plot(data_plot)
 
19
  from PIL import Image
20
 
21
  ## Functions
22
+ def raw_preds_to_df(raw, quantiles = None):
23
  """
24
  raw is output of model.predict with return_index=True
25
  quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
 
31
  dec_len = raw.output.prediction.shape[1]
32
  n_quantiles = preds.shape[-1]
33
  preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
34
+ preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
35
+ preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
36
+ preds_df = preds_df.assign(pred=preds.flatten().cpu().numpy())
37
  if quantiles is not None:
38
+ preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})
39
 
40
+ preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
 
 
41
  return preds_df
42
 
43
  def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
 
55
  df = TimeSeriesDataSet.from_parameters(_parameters, df)
56
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
57
 
58
+ def predict(_model, _dataloader):
59
+ out = model.predict(_dataloader, mode="raw", return_x=True,return_index=True, trainer_kwargs=dict(accelerator="cpu"))
60
+ preds = raw_preds_to_df(out)
 
61
  return preds[["pred_idx", "Group", "pred"]]
62
 
63
  def adjust_data_for_plot(df, preds):
 
154
  rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
155
 
156
  dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
157
+ preds = predict(model, dataloader)
158
 
159
  data_plot = adjust_data_for_plot(df.copy(), preds)
160
  fig, _ = generate_plot(data_plot)