transpolymer committed
Commit 4b283df · verified · 1 Parent(s): c748d73

Update prediction.py

Files changed (1):
  1. prediction.py +108 -86

prediction.py CHANGED
@@ -1,16 +1,25 @@
+import streamlit as st
 import torch
 import torch.nn as nn
-import joblib
 import numpy as np
+import joblib
+from transformers import AutoTokenizer, AutoModel
 from rdkit import Chem
 from rdkit.Chem import Descriptors
-from transformers import AutoTokenizer, AutoModel
+from datetime import datetime
+from db import get_database  # This must be available in your repo
 
-# Load ChemBERTa tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
-embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+# Load ChemBERTa model + tokenizer
+@st.cache_resource
+def load_chemberta():
+    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+    model.eval()
+    return tokenizer, model
 
-# Load saved scalers for inverse transformations
+tokenizer, chemberta = load_chemberta()
+
+# Load scalers
 scalers = {
     "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
     "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
@@ -20,29 +29,9 @@ scalers = {
     "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
 }
 
-# Descriptor calculation
-def compute_descriptors(smiles: str):
-    mol = Chem.MolFromSmiles(smiles)
-    if mol is None:
-        raise ValueError("Invalid SMILES string.")
-
-    descriptors = [
-        Descriptors.MolWt(mol),
-        Descriptors.MolLogP(mol),
-        Descriptors.TPSA(mol),
-        Descriptors.NumRotatableBonds(mol),
-        Descriptors.NumHDonors(mol),
-        Descriptors.NumHAcceptors(mol),
-        Descriptors.FractionCSP3(mol),
-        Descriptors.HeavyAtomCount(mol),
-        Descriptors.RingCount(mol),
-        Descriptors.MolMR(mol)
-    ]
-    return np.array(descriptors, dtype=np.float32)
-
-# Transformer regression model definition (must match training)
+# Model Definition
 class TransformerRegressor(nn.Module):
-    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
+    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
         super().__init__()
         self.feat_proj = nn.Linear(input_dim, embedding_dim)
         encoder_layer = nn.TransformerEncoderLayer(
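
The new `__init__` bakes the training hyperparameters in as defaults; `input_dim=2058` corresponds to the 768-dim ChemBERTa embedding plus 1290 descriptor slots, the arithmetic spelled out by the module-level constants removed in the next hunk. The diff elides the rest of the class body; the following is a hypothetical reconstruction consistent with the visible lines (`nhead`, `batch_first`, and the single-linear regression head are assumptions that would have to match the training-time definition):

```python
# Hypothetical reconstruction of the elided middle of TransformerRegressor.
# Only feat_proj, the encoder_layer call, the mean-pool, and regression_head
# are visible in the diff; everything marked "assumption" below is guessed.
import torch
import torch.nn as nn

class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,               # assumption: head count not shown in the diff
            dim_feedforward=ff_dim,
            batch_first=True,      # assumption: matches the (batch, seq, feat) input
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Linear(embedding_dim, output_dim)  # assumed head

    def forward(self, x):          # x: (batch, seq_len=1, input_dim)
        x = self.feat_proj(x)
        x = self.encoder(x)
        x = x.mean(dim=1)          # pool over the length-1 sequence
        return self.regression_head(x)

if __name__ == "__main__":
    m = TransformerRegressor()
    print(m(torch.randn(1, 1, 2058)).shape)  # torch.Size([1, 6])
```

Note that with a length-1 sequence, self-attention can only attend to the single token, so the encoder effectively acts per-sample on the projected feature vector.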
@@ -67,62 +56,95 @@ class TransformerRegressor(nn.Module):
         x = x.mean(dim=1)
         return self.regression_head(x)
 
-# Model hyperparameters (must match training)
-embedding_dim = 768
-descriptor_dim = 1290  # Based on earlier errors. If unsure, use 1290
-input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
-ff_dim = 1024
-num_layers = 2
-output_dim = 6
-
-# Load trained model
-device = torch.device("cpu")
-model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
-model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
-model.eval()
-
-# Prediction function
-def predict_properties(smiles: str):
-    try:
-        # Compute descriptors
-        descriptors = compute_descriptors(smiles)
-        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
-
-        # Get ChemBERTa embedding (CLS token)
-        inputs = tokenizer(smiles, return_tensors="pt")
-        with torch.no_grad():
-            outputs = embedding_model(**inputs)
-        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
-
-        # Combine features
-        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # Shape: (1, 1, 2058)
-
-        # Forward pass
-        with torch.no_grad():
-            preds = model(combined)
-
-        preds_np = preds.numpy()
-
-        # Inverse transform each property
-        keys = list(scalers.keys())
-        preds_rescaled = np.concatenate([
-            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
-            for i in range(output_dim)
-        ], axis=1)
-
-        results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
-        return results
-
-    except Exception as e:
-        return {"error": str(e)}
-
-# Show function to print the results
-def show(smiles: str):
-    result = predict_properties(smiles)
-
-    if "error" in result:
-        print(f"Error: {result['error']}")
-    else:
-        print("Predicted Properties for SMILES:", smiles)
-        for key, value in result.items():
-            print(f"{key}: {value}")
+# Load model
+@st.cache_resource
+def load_model():
+    model = TransformerRegressor()
+    model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
+    model.eval()
+    return model
+
+model = load_model()
+
+# Descriptor computation
+def compute_descriptors(smiles: str):
+    mol = Chem.MolFromSmiles(smiles)
+    if mol is None:
+        raise ValueError("Invalid SMILES string.")
+
+    descriptors = [
+        Descriptors.MolWt(mol),
+        Descriptors.MolLogP(mol),
+        Descriptors.TPSA(mol),
+        Descriptors.NumRotatableBonds(mol),
+        Descriptors.NumHDonors(mol),
+        Descriptors.NumHAcceptors(mol),
+        Descriptors.FractionCSP3(mol),
+        Descriptors.HeavyAtomCount(mol),
+        Descriptors.RingCount(mol),
+        Descriptors.MolMR(mol)
+    ]
+    return np.array(descriptors, dtype=np.float32)
+
+# Embedding function
+def get_chemberta_embedding(smiles: str):
+    inputs = tokenizer(smiles, return_tensors="pt")
+    with torch.no_grad():
+        outputs = chemberta(**inputs)
+    return outputs.last_hidden_state[:, 0, :]  # CLS token
+
+# Save prediction to MongoDB
+def save_to_db(smiles, predictions):
+    predictions_clean = {k: float(v) for k, v in predictions.items()}
+    doc = {
+        "smiles": smiles,
+        "predictions": predictions_clean,
+        "timestamp": datetime.now()
+    }
+    db = get_database()
+    db["polymer_predictions"].insert_one(doc)
+
+# Main Streamlit UI + prediction
+def show():
+    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
+    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
+
+    smiles_input = st.text_input("Enter SMILES Representation of Polymer")
+
+    if st.button("Predict"):
+        try:
+            mol = Chem.MolFromSmiles(smiles_input)
+            if mol is None:
+                st.error("Invalid SMILES string.")
+                return
+
+            descriptors = compute_descriptors(smiles_input)
+            descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
+
+            embedding = get_chemberta_embedding(smiles_input)
+
+            combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # (1, 1, 2058)
+
+            with torch.no_grad():
+                preds = model(combined)
+
+            preds_np = preds.numpy()
+
+            keys = list(scalers.keys())
+            preds_rescaled = np.concatenate([
+                scalers[keys[i]].inverse_transform(preds_np[:, [i]])
+                for i in range(6)
+            ], axis=1)
+
+            results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
+
+            # Display results
+            st.success("Predicted Properties:")
+            for key, val in results.items():
+                st.markdown(f"**{key}**: {val}")
+
+            # Save to MongoDB
+            save_to_db(smiles_input, results)
+
+        except Exception as e:
+            st.error(f"Prediction failed: {e}")
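
get_chemberta_embedding() returns the hidden state of the first (CLS) token, which is 768-dimensional for this checkpoint. A quick shape check outside Streamlit, assuming only the public seyonec/ChemBERTa-zinc-base-v1 checkpoint ("CCO", ethanol, is an arbitrary test SMILES):

```python
# Standalone shape check for the CLS embedding; mirrors get_chemberta_embedding()
# without the @st.cache_resource caching.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model.eval()

inputs = tokenizer("CCO", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
print(out.last_hidden_state[:, 0, :].shape)  # torch.Size([1, 768])
```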
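compute_descriptors() packs ten RDKit descriptors into a float32 vector. A minimal check of what it produces, assuming only RDKit and the same ten Descriptors calls as in the diff:

```python
# Quick look at the descriptor vector for an arbitrary test molecule.
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors

mol = Chem.MolFromSmiles("CCO")  # ethanol
descriptors = np.array([
    Descriptors.MolWt(mol),
    Descriptors.MolLogP(mol),
    Descriptors.TPSA(mol),
    Descriptors.NumRotatableBonds(mol),
    Descriptors.NumHDonors(mol),
    Descriptors.NumHAcceptors(mol),
    Descriptors.FractionCSP3(mol),
    Descriptors.HeavyAtomCount(mol),
    Descriptors.RingCount(mol),
    Descriptors.MolMR(mol),
], dtype=np.float32)
print(descriptors.shape)  # (10,)
```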
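The six per-property scalers are applied column by column with inverse_transform, mapping the network's scaled outputs back to physical units. A self-contained round-trip illustration, assuming the .joblib files hold fitted scikit-learn scalers (StandardScaler here is an assumption; any fitted scaler exposing inverse_transform behaves the same way):

```python
# Round trip through one scaler, as the prediction loop does per output column.
import numpy as np
from sklearn.preprocessing import StandardScaler

y_train = np.array([[10.0], [20.0], [30.0]])  # toy training targets (e.g. MPa)
scaler = StandardScaler().fit(y_train)         # what would be dumped to .joblib

pred_scaled = np.array([[0.5]])                # a model output in scaled space
print(scaler.inverse_transform(pred_scaled))   # [[24.0825...]] back in MPa
```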
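save_to_db() relies on get_database() from a local db module that is not part of this commit (the import is flagged "must be available in your repo"). A hypothetical sketch of such a module, assuming MongoDB via pymongo; the environment variable and database name are placeholders, not the repo's actual values:

```python
# db.py — hypothetical sketch, not the repo's actual module.
# Assumes pymongo is installed and MONGODB_URI holds a connection string.
import os
from pymongo import MongoClient

def get_database():
    client = MongoClient(os.environ["MONGODB_URI"])  # placeholder env var
    return client["transpolymer"]                    # placeholder DB name
```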
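
With this commit, show() takes no arguments and reads the SMILES from the st.text_input widget, so the module is meant to be driven by Streamlit rather than called with a string. A minimal way to run it, assuming the file keeps the name prediction.py (app.py is a placeholder):

```python
# app.py — hypothetical entry point; launch with: streamlit run app.py
import prediction  # module-level code loads ChemBERTa, the scalers, and the model

prediction.show()
```

Because the heavy loaders are wrapped in @st.cache_resource, ChemBERTa and the regressor are initialized once per session rather than on every script rerun.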