transpolymer committed on
Commit
cf36af6
·
verified ·
1 Parent(s): af4f3b0

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +27 -21
prediction.py CHANGED
@@ -9,8 +9,16 @@ from rdkit.Chem import Descriptors
9
  from rdkit.Chem import AllChem
10
  from datetime import datetime
11
  from db import get_database # This must be available in your repo
 
12
 
13
- # Load ChemBERTa model + tokenizer
 
 
 
 
 
 
 
14
  @st.cache_resource
15
  def load_chemberta():
16
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
@@ -20,7 +28,7 @@ def load_chemberta():
20
 
21
  tokenizer, chemberta = load_chemberta()
22
 
23
- # Load scalers
24
  scalers = {
25
  "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
26
  "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
@@ -30,7 +38,7 @@ scalers = {
30
  "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
31
  }
32
 
33
- # Model Definition
34
  class TransformerRegressor(nn.Module):
35
  def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
36
  super().__init__()
@@ -39,7 +47,7 @@ class TransformerRegressor(nn.Module):
39
  d_model=embedding_dim,
40
  nhead=8,
41
  dim_feedforward=ff_dim,
42
- dropout=0.1,
43
  batch_first=True
44
  )
45
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
@@ -57,7 +65,7 @@ class TransformerRegressor(nn.Module):
57
  x = x.mean(dim=1)
58
  return self.regression_head(x)
59
 
60
- # Load model
61
  @st.cache_resource
62
  def load_model():
63
  model = TransformerRegressor()
@@ -67,7 +75,7 @@ def load_model():
67
 
68
  model = load_model()
69
 
70
- # Descriptor computation
71
  def compute_descriptors(smiles: str):
72
  mol = Chem.MolFromSmiles(smiles)
73
  if mol is None:
@@ -87,22 +95,22 @@ def compute_descriptors(smiles: str):
87
  ]
88
  return np.array(descriptors, dtype=np.float32)
89
 
90
- # Fingerprint computation
91
  def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
92
  mol = Chem.MolFromSmiles(smiles)
93
  if mol is None:
94
  raise ValueError("Invalid SMILES string.")
95
  fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
96
- return np.array(fp, dtype=np.float32).reshape(1, -1) # (1, 1280)
97
 
98
- # Embedding function
99
  def get_chemberta_embedding(smiles: str):
100
  inputs = tokenizer(smiles, return_tensors="pt")
101
  with torch.no_grad():
102
  outputs = chemberta(**inputs)
103
- return outputs.last_hidden_state[:, 0, :] # CLS token (1, 768)
104
 
105
- # Save prediction to MongoDB
106
  def save_to_db(smiles, predictions):
107
  predictions_clean = {k: float(v) for k, v in predictions.items()}
108
  doc = {
@@ -113,7 +121,7 @@ def save_to_db(smiles, predictions):
113
  db = get_database()
114
  db["polymer_predictions"].insert_one(doc)
115
 
116
- # Main Streamlit UI + prediction
117
  def show():
118
  st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
119
  st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
@@ -128,22 +136,22 @@ def show():
128
  return
129
 
130
  descriptors = compute_descriptors(smiles_input)
131
- descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0) # (1, 10)
132
 
133
- fingerprint = get_morgan_fingerprint(smiles_input) # (1, 1280)
134
- fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32) # (1, 1280)
135
 
136
- embedding = get_chemberta_embedding(smiles_input) # (1, 768)
137
 
138
- combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1) # (1, 2058)
139
- combined = combined_input.unsqueeze(1) # (1, 1, 2058)
140
 
141
  with torch.no_grad():
142
  preds = model(combined)
143
 
144
  preds_np = preds.numpy()
145
-
146
  keys = list(scalers.keys())
 
147
  preds_rescaled = np.concatenate([
148
  scalers[keys[i]].inverse_transform(preds_np[:, [i]])
149
  for i in range(6)
@@ -151,12 +159,10 @@ def show():
151
 
152
  results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
153
 
154
- # Display results
155
  st.success("Predicted Properties:")
156
  for key, val in results.items():
157
  st.markdown(f"**{key}**: {val}")
158
 
159
- # Save to MongoDB
160
  save_to_db(smiles_input, results)
161
 
162
  except Exception as e:
 
9
  from rdkit.Chem import AllChem
10
  from datetime import datetime
11
  from db import get_database # This must be available in your repo
12
+ import random
13
 
14
+ # ------------------------ Ensuring Deterministic Behavior ------------------------
15
+ random.seed(42)
16
+ np.random.seed(42)
17
+ torch.manual_seed(42)
18
+ torch.backends.cudnn.deterministic = True
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+ # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
22
  @st.cache_resource
23
  def load_chemberta():
24
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
28
 
29
  tokenizer, chemberta = load_chemberta()
30
 
31
+ # ------------------------ Load Scalers ------------------------
32
  scalers = {
33
  "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
34
  "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
 
38
  "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
39
  }
40
 
41
+ # ------------------------ Transformer Model ------------------------
42
  class TransformerRegressor(nn.Module):
43
  def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
44
  super().__init__()
 
47
  d_model=embedding_dim,
48
  nhead=8,
49
  dim_feedforward=ff_dim,
50
+ dropout=0.0, # No dropout for consistency
51
  batch_first=True
52
  )
53
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
65
  x = x.mean(dim=1)
66
  return self.regression_head(x)
67
 
68
+ # ------------------------ Load Model ------------------------
69
  @st.cache_resource
70
  def load_model():
71
  model = TransformerRegressor()
 
75
 
76
  model = load_model()
77
 
78
+ # ------------------------ Descriptors ------------------------
79
  def compute_descriptors(smiles: str):
80
  mol = Chem.MolFromSmiles(smiles)
81
  if mol is None:
 
95
  ]
96
  return np.array(descriptors, dtype=np.float32)
97
 
98
+ # ------------------------ Fingerprints ------------------------
99
def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
    """Return the Morgan fingerprint of *smiles* as a (1, n_bits) float32 array.

    Raises:
        ValueError: if *smiles* cannot be parsed into an RDKit molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")
    # Circular (Morgan/ECFP-style) fingerprint as a fixed-length bit vector.
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
    # Row vector so it can be concatenated with the other per-molecule features.
    return np.array(bit_vector, dtype=np.float32).reshape(1, -1)
105
 
106
+ # ------------------------ Embedding ------------------------
107
def get_chemberta_embedding(smiles: str):
    """Return a ChemBERTa embedding for *smiles*.

    Tokenizes the SMILES string, runs the (module-level) ChemBERTa model
    without gradient tracking, and mean-pools the last hidden states over
    the sequence dimension. Result shape is (1, hidden_size) — presumably
    768 for this checkpoint; confirm against the model config.
    """
    encoded = tokenizer(smiles, return_tensors="pt")
    with torch.no_grad():
        model_output = chemberta(**encoded)
    # Average over all token positions instead of taking the CLS token.
    return model_output.last_hidden_state.mean(dim=1)
112
 
113
+ # ------------------------ Save to DB ------------------------
114
  def save_to_db(smiles, predictions):
115
  predictions_clean = {k: float(v) for k, v in predictions.items()}
116
  doc = {
 
121
  db = get_database()
122
  db["polymer_predictions"].insert_one(doc)
123
 
124
+ # ------------------------ Streamlit App ------------------------
125
  def show():
126
  st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
127
  st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
 
136
  return
137
 
138
  descriptors = compute_descriptors(smiles_input)
139
+ descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
140
 
141
+ fingerprint = get_morgan_fingerprint(smiles_input)
142
+ fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32)
143
 
144
+ embedding = get_chemberta_embedding(smiles_input)
145
 
146
+ combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
147
+ combined = combined_input.unsqueeze(1)
148
 
149
  with torch.no_grad():
150
  preds = model(combined)
151
 
152
  preds_np = preds.numpy()
 
153
  keys = list(scalers.keys())
154
+
155
  preds_rescaled = np.concatenate([
156
  scalers[keys[i]].inverse_transform(preds_np[:, [i]])
157
  for i in range(6)
 
159
 
160
  results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
161
 
 
162
  st.success("Predicted Properties:")
163
  for key, val in results.items():
164
  st.markdown(f"**{key}**: {val}")
165
 
 
166
  save_to_db(smiles_input, results)
167
 
168
  except Exception as e: