leuschnm commited on
Commit
a49f957
·
1 Parent(s): c25a5e4

change conversion

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -96,7 +96,7 @@ def generate_plot(df):
96
 
97
  @st.cache_data
98
  def load_data():
99
- with open('data/parameters_q.pkl', 'rb') as f:
100
  parameters = pickle.load(f)
101
  df = pd.read_pickle('data/test_data.pkl')
102
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
@@ -104,7 +104,7 @@ def load_data():
104
 
105
  @st.cache_resource
106
  def init_model():
107
- model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
108
  return model
109
 
110
  def main():
 
96
 
97
  @st.cache_data
98
  def load_data():
99
+ with open('data/parameters.pkl', 'rb') as f:
100
  parameters = pickle.load(f)
101
  df = pd.read_pickle('data/test_data.pkl')
102
  df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
 
104
 
105
  @st.cache_resource
106
  def init_model():
107
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
108
  return model
109
 
110
  def main():