transpolymer commited on
Commit
ca63203
·
verified ·
1 Parent(s): cf6ea36

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +13 -1
prediction.py CHANGED
@@ -65,12 +65,24 @@ class TransformerRegressor(nn.Module):
65
  x = x.mean(dim=1)
66
  return self.regression_head(x)
67
 
 
68
  # ------------------------ Load Model ------------------------
69
  @st.cache_resource
70
  def load_model():
71
- model = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
 
 
 
 
 
 
 
 
 
72
  model.eval()
73
  return model
 
 
74
  model = load_model()
75
 
76
  # ------------------------ Descriptors ------------------------
 
65
  x = x.mean(dim=1)
66
  return self.regression_head(x)
67
 
68
+ # ------------------------ Load Model ------------------------
69
  # ------------------------ Load Model ------------------------
70
  @st.cache_resource
71
  def load_model():
72
+ # Initialize the model architecture first
73
+ model = TransformerRegressor()
74
+
75
+ # Load the state_dict (weights) from the saved model file
76
+ state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
77
+
78
+ # Load the state_dict into the model
79
+ model.load_state_dict(state_dict)
80
+
81
+ # Set the model to evaluation mode
82
  model.eval()
83
  return model
84
+
85
+ # Load the model
86
  model = load_model()
87
 
88
  # ------------------------ Descriptors ------------------------