transpolymer commited on
Commit
cf6ea36
·
verified ·
1 Parent(s): 6337dea

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +1 -4
prediction.py CHANGED
@@ -68,12 +68,9 @@ class TransformerRegressor(nn.Module):
68
  # ------------------------ Load Model ------------------------
69
  @st.cache_resource
70
  def load_model():
71
- model = TransformerRegressor()
72
- state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
73
- model.load_state_dict(state_dict)
74
  model.eval()
75
  return model
76
-
77
  model = load_model()
78
 
79
  # ------------------------ Descriptors ------------------------
 
68
  # ------------------------ Load Model ------------------------
69
  @st.cache_resource
70
  def load_model():
71
+ model = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
 
 
72
  model.eval()
73
  return model
 
74
  model = load_model()
75
 
76
  # ------------------------ Descriptors ------------------------