transpolymer committed · Commit 8e77cfc · verified · 1 parent: c530f4f

Update prediction.py

Files changed (1):
  1. prediction.py (+8, -20)
prediction.py CHANGED
@@ -8,10 +8,9 @@ from rdkit import Chem
 from rdkit.Chem import Descriptors
 from rdkit.Chem import AllChem
 from datetime import datetime
-from db import get_database # Ensure this module is available
+from db import get_database
 import random
-
-
+import os # <-- Added for debugging file paths
 
 # ------------------------ Ensuring Deterministic Behavior ------------------------
 random.seed(42)
@@ -20,7 +19,6 @@ torch.manual_seed(42)
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
-# Check if CUDA is available for GPU acceleration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
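For completeness, a standalone sketch of the seeding and device selection this file depends on; the numpy seed is an assumption (only the random and torch seeds appear in this diff):

import random

import numpy as np
import torch

random.seed(42)
np.random.seed(42)                          # assumed; not shown in this diff
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True   # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False      # disable autotuning, which is nondeterministic

# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)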
@@ -29,7 +27,7 @@ def load_chemberta():
     tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
     model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
     model.eval()
-    model.to(device) # Send model to GPU if available
+    model.to(device)
     return tokenizer, model
 
 # ------------------------ Load Scalers ------------------------
@@ -51,7 +49,7 @@ class TransformerRegressor(nn.Module):
             d_model=embedding_dim,
             nhead=8,
             dim_feedforward=ff_dim,
-            dropout=0.0, # No dropout for consistency
+            dropout=0.0,
             batch_first=True
         )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
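For context, a minimal sketch of an encoder layer configured like the one in this class; the values of embedding_dim, ff_dim, and num_layers below are placeholders, not the model's actual hyperparameters:

import torch
import torch.nn as nn

embedding_dim, ff_dim, num_layers = 512, 1024, 2   # placeholder values

encoder_layer = nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=8,
    dim_feedforward=ff_dim,
    dropout=0.0,        # disabled, in line with the deterministic setup above
    batch_first=True,   # inputs arrive as (batch, seq_len, d_model)
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

x = torch.randn(1, 1, embedding_dim)   # one "token" per sample, as in this script
print(encoder(x).shape)                # torch.Size([1, 1, 512])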
@@ -72,18 +70,15 @@ class TransformerRegressor(nn.Module):
 # ------------------------ Load Model ------------------------
 @st.cache_resource
 def load_model():
-    # Initialize the model architecture first
     model = TransformerRegressor()
-
-    # Load the state_dict (weights) from the saved model file
     try:
-        state_dict = torch.load("transformer_model(1).bin",map_location=device) # Ensure loading on the correct device
+        print("Files in working directory:", os.listdir()) # <-- Debug print
+        state_dict = torch.load("transformer_model.bin", map_location=device)
         model.load_state_dict(state_dict)
         model.eval()
-        model.to(device) # Send model to GPU if available
+        model.to(device)
     except Exception as e:
         raise ValueError(f"Failed to load model: {e}")
-
     return model
 
 # ------------------------ Descriptors ------------------------
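Since this commit renames the checkpoint to transformer_model.bin and adds an os.listdir() debug print, here is a hedged sketch of the same loading pattern with an explicit existence check; the filename comes from the new hunk, everything else is generic PyTorch:

import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "transformer_model.bin"

if not os.path.exists(ckpt_path):
    # Mirrors the debug print added in this commit: show what *is* present.
    raise FileNotFoundError(f"{ckpt_path} not found; directory has: {os.listdir()}")

state_dict = torch.load(ckpt_path, map_location=device)
model = TransformerRegressor()   # the architecture defined earlier in prediction.py
model.load_state_dict(state_dict)
model.eval()
model.to(device)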
@@ -91,7 +86,6 @@ def compute_descriptors(smiles: str):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
         raise ValueError("Invalid SMILES string.")
-
     descriptors = [
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
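A small worked example of the RDKit calls used here, on ethanol; only MolWt and MolLogP are shown because those are the descriptors visible in this hunk:

from rdkit import Chem
from rdkit.Chem import Descriptors

mol = Chem.MolFromSmiles("CCO")   # ethanol
if mol is None:
    raise ValueError("Invalid SMILES string.")

print(Descriptors.MolWt(mol))     # ~46.07 g/mol
print(Descriptors.MolLogP(mol))   # Crippen logP, a small negative value for ethanol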
@@ -119,7 +113,7 @@ def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
     inputs = tokenizer(smiles, return_tensors="pt")
     with torch.no_grad():
         outputs = chemberta(**inputs)
-    return outputs.last_hidden_state.mean(dim=1) # Use average instead of CLS token
+    return outputs.last_hidden_state.mean(dim=1)
 
 # ------------------------ Save to DB ------------------------
 def save_to_db(smiles, predictions):
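The retained return line mean-pools the token embeddings rather than taking the CLS vector; a minimal standalone sketch (the 768 in the expected shape is the usual hidden size for this checkpoint, not something the diff states):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").eval()

inputs = tokenizer("CCO", return_tensors="pt")
with torch.no_grad():
    outputs = chemberta(**inputs)

embedding = outputs.last_hidden_state.mean(dim=1)   # average over all tokens
print(embedding.shape)                              # expected: torch.Size([1, 768])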
@@ -141,7 +135,6 @@ def show():
 
     if st.button("Predict"):
         try:
-            # Load the model
             model = load_model()
 
             mol = Chem.MolFromSmiles(smiles_input)
@@ -149,10 +142,8 @@ def show():
                 st.error("Invalid SMILES string.")
                 return
 
-            # Load the ChemBERTa tokenizer and model
             tokenizer, chemberta = load_chemberta()
 
-            # Compute Descriptors, Fingerprints, and Embedding
             descriptors = compute_descriptors(smiles_input)
             descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
 
@@ -161,7 +152,6 @@ def show():
 
             embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)
 
-            # Combine Inputs and Make Prediction
             combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
             combined = combined_input.unsqueeze(1)
 
@@ -171,7 +161,6 @@ def show():
             preds_np = preds.numpy()
             keys = list(scalers.keys())
 
-            # Rescale Predictions
             preds_rescaled = np.concatenate([
                 scalers[keys[i]].inverse_transform(preds_np[:, [i]])
                 for i in range(6)
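The rescaling loop inverts one scaler per target; a sketch assuming scikit-learn StandardScaler objects, since the actual scaler type is not shown in this diff:

import numpy as np
from sklearn.preprocessing import StandardScaler

# Pretend each of the 6 targets was scaled independently during training.
scalers = {f"prop_{i}": StandardScaler().fit(np.random.rand(100, 1)) for i in range(6)}

preds_np = np.random.rand(1, 6)   # stand-in for the model output
keys = list(scalers.keys())

preds_rescaled = np.concatenate([
    scalers[keys[i]].inverse_transform(preds_np[:, [i]])
    for i in range(6)
], axis=1)
print(preds_rescaled.shape)       # (1, 6)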
@@ -183,7 +172,6 @@ def show():
             for key, val in results.items():
                 st.markdown(f"**{key}**: {val}")
 
-            # Save the results to the database
             save_to_db(smiles_input, results)
 
         except Exception as e:
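save_to_db and the restored db.get_database import are project-specific; what follows is a hypothetical sketch that assumes get_database() returns a pymongo database handle, which this diff does not confirm:

from datetime import datetime

from db import get_database   # project-specific helper restored by this commit


def save_to_db(smiles, predictions):
    # Hypothetical schema: one document per prediction run.
    db = get_database()
    db["predictions"].insert_one({
        "smiles": smiles,
        "predictions": predictions,
        "timestamp": datetime.now(),
    })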
 