transpolymer committed on
Commit
3b0f51a
·
verified ·
1 Parent(s): ca63203

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +5 -8
prediction.py CHANGED
@@ -24,9 +24,10 @@ def load_chemberta():
24
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
25
  model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
26
  model.eval()
 
27
  return tokenizer, model
28
-
29
- tokenizer, chemberta = load_chemberta()
30
 
31
  # ------------------------ Load Scalers ------------------------
32
  scalers = {
@@ -65,7 +66,6 @@ class TransformerRegressor(nn.Module):
65
  x = x.mean(dim=1)
66
  return self.regression_head(x)
67
 
68
- # ------------------------ Load Model ------------------------
69
  # ------------------------ Load Model ------------------------
70
  @st.cache_resource
71
  def load_model():
@@ -73,18 +73,15 @@ def load_model():
73
  model = TransformerRegressor()
74
 
75
  # Load the state_dict (weights) from the saved model file
76
- state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
77
 
78
  # Load the state_dict into the model
79
  model.load_state_dict(state_dict)
80
 
81
  # Set the model to evaluation mode
82
  model.eval()
 
83
  return model
84
-
85
- # Load the model
86
- model = load_model()
87
-
88
  # ------------------------ Descriptors ------------------------
89
  def compute_descriptors(smiles: str):
90
  mol = Chem.MolFromSmiles(smiles)
 
24
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
25
  model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
26
  model.eval()
27
+ model.to(device) # Send model to GPU if available
28
  return tokenizer, model
29
+ # Check if CUDA is available for GPU acceleration
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
  # ------------------------ Load Scalers ------------------------
33
  scalers = {
 
66
  x = x.mean(dim=1)
67
  return self.regression_head(x)
68
 
 
69
  # ------------------------ Load Model ------------------------
70
  @st.cache_resource
71
  def load_model():
 
73
  model = TransformerRegressor()
74
 
75
  # Load the state_dict (weights) from the saved model file
76
+ state_dict = torch.load("transformer_model.bin", map_location=device) # Ensure loading on the correct device
77
 
78
  # Load the state_dict into the model
79
  model.load_state_dict(state_dict)
80
 
81
  # Set the model to evaluation mode
82
  model.eval()
83
+ model.to(device) # Send model to GPU if available
84
  return model
 
 
 
 
85
  # ------------------------ Descriptors ------------------------
86
  def compute_descriptors(smiles: str):
87
  mol = Chem.MolFromSmiles(smiles)