transpolymer committed
Commit 77a4bbb · verified · 1 parent: a08c2a4

Update prediction.py

Files changed (1):
  1. prediction.py +36 -46
prediction.py CHANGED
@@ -5,10 +5,9 @@ import numpy as np
 import joblib
 from transformers import AutoTokenizer, AutoModel
 from rdkit import Chem
-from rdkit.Chem import Descriptors
-from rdkit.Chem import AllChem
+from rdkit.Chem import Descriptors, AllChem
 from datetime import datetime
-from db import get_database  # This must be available in your repo
+from db import get_database
 import random
 
 # ------------------------ Ensuring Deterministic Behavior ------------------------
@@ -17,38 +16,25 @@ np.random.seed(42)
 torch.manual_seed(42)
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
-# Check if CUDA is available for GPU acceleration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
 @st.cache_resource
 def load_chemberta():
     tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
     model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
     model.eval()
-    model.to(device)  # Send model to GPU if available
+    model.to(device)
     return tokenizer, model
 
-
-# ------------------------ Load Scalers ------------------------
-scalers = {
-    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
-    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
-    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
-    "logP": joblib.load("scaler_LogP.joblib"),
-    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
-    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
-}
-
-# ------------------------ Transformer Model ------------------------
+# ------------------------ Load Transformer Model ------------------------
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
         super().__init__()
         self.feat_proj = nn.Linear(input_dim, embedding_dim)
         encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embedding_dim,
-            nhead=8,
-            dim_feedforward=ff_dim,
-            dropout=0.0,  # No dropout for consistency
+            d_model=embedding_dim, nhead=8,
+            dim_feedforward=ff_dim, dropout=0.0,
             batch_first=True
         )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
@@ -68,26 +54,33 @@ class TransformerRegressor(nn.Module):
 
 @st.cache_resource
 def load_model():
-    # Initialize the model architecture first
     model = TransformerRegressor()
-
-    # Load the state_dict (weights) from the saved model file
-    state_dict = torch.load("transformer_model(1).bin", map_location=device)  # Ensure loading on the correct device
-
-    # Load the state_dict into the model
+    state_dict = torch.load("transformer_model(1).bin", map_location=device)
     model.load_state_dict(state_dict)
-
-    # Set the model to evaluation mode
     model.eval()
-    model.to(device)  # Send model to GPU if available
+    model.to(device)
     return model
+
+# ✅ Load tokenizer/model globally
+tokenizer, chemberta = load_chemberta()
+model = load_model()
+
+# ------------------------ Load Scalers ------------------------
+scalers = {
+    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
+    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
+    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
+    "logP": joblib.load("scaler_LogP.joblib"),
+    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
+    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
+}
+
 # ------------------------ Descriptors ------------------------
 def compute_descriptors(smiles: str):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
         raise ValueError("Invalid SMILES string.")
-
-    descriptors = [
+    return np.array([
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
         Descriptors.TPSA(mol),
@@ -98,8 +91,7 @@ def compute_descriptors(smiles: str):
         Descriptors.HeavyAtomCount(mol),
         Descriptors.RingCount(mol),
         Descriptors.MolMR(mol)
-    ]
-    return np.array(descriptors, dtype=np.float32)
+    ], dtype=np.float32)
 
 # ------------------------ Fingerprints ------------------------
 def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
@@ -111,23 +103,22 @@ def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
 
 # ------------------------ Embedding ------------------------
 def get_chemberta_embedding(smiles: str):
-    inputs = tokenizer(smiles, return_tensors="pt")
+    inputs = tokenizer(smiles, return_tensors="pt").to(device)
     with torch.no_grad():
         outputs = chemberta(**inputs)
-    return outputs.last_hidden_state.mean(dim=1)  # Use average instead of CLS token
+    return outputs.last_hidden_state.mean(dim=1).cpu()
 
 # ------------------------ Save to DB ------------------------
 def save_to_db(smiles, predictions):
-    predictions_clean = {k: float(v) for k, v in predictions.items()}
     doc = {
         "smiles": smiles,
-        "predictions": predictions_clean,
+        "predictions": {k: float(v) for k, v in predictions.items()},
         "timestamp": datetime.now()
     }
     db = get_database()
     db["polymer_predictions"].insert_one(doc)
 
-# ------------------------ Streamlit App ------------------------
+# ------------------------ Streamlit UI ------------------------
 def show():
     st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
     st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
@@ -150,17 +141,14 @@ def show():
             embedding = get_chemberta_embedding(smiles_input)
 
             combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
-            combined = combined_input.unsqueeze(1)
+            combined = combined_input.unsqueeze(1).to(device)
 
             with torch.no_grad():
-                preds = model(combined)
+                preds = model(combined).cpu().numpy()
 
-            preds_np = preds.numpy()
             keys = list(scalers.keys())
-
             preds_rescaled = np.concatenate([
-                scalers[keys[i]].inverse_transform(preds_np[:, [i]])
-                for i in range(6)
+                scalers[key].inverse_transform(preds[:, [i]]) for i, key in enumerate(keys)
             ], axis=1)
 
             results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
@@ -172,4 +160,6 @@ def show():
             save_to_db(smiles_input, results)
 
         except Exception as e:
-            st.error(f"Prediction failed: {e}")
+            st.error(f"Prediction failed: {e}")
+
+
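For context, the refactored module can be exercised outside Streamlit roughly as follows. This is a minimal sketch, assuming prediction.py's globals (model, scalers, device) are importable, the scaler .joblib files and "transformer_model(1).bin" are on disk, and get_morgan_fingerprint (whose body is elided in this diff) returns a length-1280 numeric vector, as its n_bits default suggests; the SMILES string is a hypothetical test input.

```python
# Minimal non-Streamlit sketch of the prediction path this commit settles on.
import numpy as np
import torch

from prediction import (
    compute_descriptors,
    get_morgan_fingerprint,
    get_chemberta_embedding,
    device, model, scalers,
)

smiles = "CC(=O)Oc1ccccc1C(=O)O"  # hypothetical test input (aspirin)

# Build the 2058-dim input: 768 (ChemBERTa) + 10 (descriptors) + 1280 (Morgan)
descriptors = torch.tensor(compute_descriptors(smiles)).unsqueeze(0)
fingerprint = torch.tensor(
    np.asarray(get_morgan_fingerprint(smiles)), dtype=torch.float32
).unsqueeze(0)
embedding = get_chemberta_embedding(smiles)  # (1, 768), already moved to CPU

combined = torch.cat([embedding, descriptors, fingerprint], dim=1)
with torch.no_grad():
    # Add the sequence dimension and run on the module's device, as show() does
    preds = model(combined.unsqueeze(1).to(device)).cpu().numpy()

# Undo each property's scaler column by column, mirroring the new show() loop
keys = list(scalers.keys())
rescaled = np.concatenate(
    [scalers[k].inverse_transform(preds[:, [i]]) for i, k in enumerate(keys)],
    axis=1,
)
print({k: round(float(v), 4) for k, v in zip(keys, rescaled.flatten())})
```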