leuschnm commited on
Commit
00e684b
·
1 Parent(s): b3d44f3
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +2 -2
  2. app.py +2 -2
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -54,12 +54,12 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  def predict(model, dataloader):
57
- return model.cpu().predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
  ## Initiate Data
60
  with open('data/parameters.pkl', 'rb') as f:
61
  parameters = pickle.load(f)
62
- model = TemporalFusionTransformer.cpu().load_from_checkpoint('model/tft_check.ckpt')
63
 
64
  df = pd.read_pickle('data/test_data.pkl')
65
  df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
 
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  def predict(model, dataloader):
57
+ return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
  ## Initiate Data
60
  with open('data/parameters.pkl', 'rb') as f:
61
  parameters = pickle.load(f)
62
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
63
 
64
  df = pd.read_pickle('data/test_data.pkl')
65
  df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
app.py CHANGED
@@ -54,12 +54,12 @@ def prepare_dataset(parameters, df, rain, temperature, datepicker):
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  def predict(model, dataloader):
57
- return model.cpu().predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
  ## Initiate Data
60
  with open('data/parameters.pkl', 'rb') as f:
61
  parameters = pickle.load(f)
62
- model = TemporalFusionTransformer.cpu().load_from_checkpoint('model/tft_check.ckpt')
63
 
64
  df = pd.read_pickle('data/test_data.pkl')
65
  df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
 
54
  return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
 
56
  def predict(model, dataloader):
57
+ return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
 
59
  ## Initiate Data
60
  with open('data/parameters.pkl', 'rb') as f:
61
  parameters = pickle.load(f)
62
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
63
 
64
  df = pd.read_pickle('data/test_data.pkl')
65
  df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]