transpolymer commited on
Commit
6337dea
·
verified ·
1 Parent(s): c1b893d

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +2 -1
prediction.py CHANGED
@@ -69,7 +69,8 @@ class TransformerRegressor(nn.Module):
69
  @st.cache_resource
70
  def load_model():
71
  model = TransformerRegressor()
72
- model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
 
73
  model.eval()
74
  return model
75
 
 
69
  @st.cache_resource
70
  def load_model():
71
  model = TransformerRegressor()
72
+ state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
73
+ model.load_state_dict(state_dict)
74
  model.eval()
75
  return model
76