transpolymer commited on
Commit
799aa56
·
verified ·
1 Parent(s): c54640c

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +6 -9
prediction.py CHANGED
@@ -70,19 +70,16 @@ class TransformerRegressor(nn.Module):
70
 
71
  @st.cache_resource
72
  def load_model():
73
- # Initialize the model architecture first
74
  model = TransformerRegressor()
75
-
76
- # Load the state_dict (weights) from the saved model file
77
- state_dict = torch.load("transformer_model(1).bin", map_location=device) # Ensure loading on the correct device
78
-
79
- # Load the state_dict into the model
80
  model.load_state_dict(state_dict)
81
-
82
- # Set the model to evaluation mode
83
  model.eval()
84
- model.to(device) # Send model to GPU if available
85
  return model
 
 
 
 
86
  # ------------------------ Descriptors ------------------------
87
  def compute_descriptors(smiles: str):
88
  mol = Chem.MolFromSmiles(smiles)
 
70
 
71
  @st.cache_resource
72
  def load_model():
 
73
  model = TransformerRegressor()
74
+ state_dict = torch.load("transformer_model(1).bin", map_location=device)
 
 
 
 
75
  model.load_state_dict(state_dict)
 
 
76
  model.eval()
77
+ model.to(device)
78
  return model
79
+
80
+ # Call them to load the actual models
81
+ tokenizer, chemberta = load_chemberta()
82
+ model = load_model()
83
  # ------------------------ Descriptors ------------------------
84
  def compute_descriptors(smiles: str):
85
  mol = Chem.MolFromSmiles(smiles)