transpolymer commited on
Commit
a08c2a4
·
verified ·
1 Parent(s): e05cbe5

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +16 -23
prediction.py CHANGED
@@ -20,15 +20,13 @@ torch.backends.cudnn.benchmark = False
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
23
- # # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
24
- # @st.cache_resource
25
- # def load_chemberta():
26
- # tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
- # model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
28
- # model.eval()
29
- # model.to(device)
30
- # return tokenizer, model
31
- # tokenizer, chemberta = load_chemberta()
32
 
33
 
34
  # ------------------------ Load Scalers ------------------------
@@ -67,27 +65,22 @@ class TransformerRegressor(nn.Module):
67
  x = self.transformer_encoder(x)
68
  x = x.mean(dim=1)
69
  return self.regression_head(x)
70
- @st.cache_resource
71
- def load_chemberta():
72
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
73
- model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
74
- model.eval()
75
- model.to(device)
76
- return tokenizer, model
77
 
78
  @st.cache_resource
79
  def load_model():
 
80
  model = TransformerRegressor()
81
- state_dict = torch.load("transformer_model(1).bin", map_location=device)
 
 
 
 
82
  model.load_state_dict(state_dict)
 
 
83
  model.eval()
84
- model.to(device)
85
  return model
86
-
87
- # Call them to load the actual models
88
- tokenizer, chemberta = load_chemberta()
89
- model = load_model()
90
-
91
  # ------------------------ Descriptors ------------------------
92
  def compute_descriptors(smiles: str):
93
  mol = Chem.MolFromSmiles(smiles)
 
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
23
+ @st.cache_resource
24
+ def load_chemberta():
25
+ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
26
+ model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
+ model.eval()
28
+ model.to(device) # Send model to GPU if available
29
+ return tokenizer, model
 
 
30
 
31
 
32
  # ------------------------ Load Scalers ------------------------
 
65
  x = self.transformer_encoder(x)
66
  x = x.mean(dim=1)
67
  return self.regression_head(x)
 
 
 
 
 
 
 
68
 
69
  @st.cache_resource
70
  def load_model():
71
+ # Initialize the model architecture first
72
  model = TransformerRegressor()
73
+
74
+ # Load the state_dict (weights) from the saved model file
75
+ state_dict = torch.load("transformer_model(1).bin", map_location=device) # Ensure loading on the correct device
76
+
77
+ # Load the state_dict into the model
78
  model.load_state_dict(state_dict)
79
+
80
+ # Set the model to evaluation mode
81
  model.eval()
82
+ model.to(device) # Send model to GPU if available
83
  return model
 
 
 
 
 
84
  # ------------------------ Descriptors ------------------------
85
  def compute_descriptors(smiles: str):
86
  mol = Chem.MolFromSmiles(smiles)