transpolymer committed
Commit c621eb3 · verified · 1 Parent(s): 33dbc5c

Update prediction.py

Files changed (1)
  1. prediction.py +81 -126
prediction.py CHANGED
@@ -1,133 +1,88 @@
  import streamlit as st
  import torch
  import numpy as np
- from transformers import AutoTokenizer, AutoModel
  from rdkit import Chem
- from rdkit.Chem import AllChem, Descriptors
- from torch import nn
- from datetime import datetime
- from db import get_database  # Assuming you have a file db.py with get_database function to connect to MongoDB
-
- # Load tokenizer and ChemBERTa model
- @st.cache_resource
- def load_chemberta():
-     tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
-     model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
-     model.eval()
-     return tokenizer, model
-
- tokenizer, chemberta = load_chemberta()
-
- # Define your model architecture
- class TransformerRegressor(nn.Module):
-     def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
-         super().__init__()
-         self.feat_proj = nn.Linear(feat_dim, emb_dim)
-         encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=1024, dropout=0.1, batch_first=True)
-         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-         self.regression_head = nn.Sequential(
-             nn.Linear(emb_dim, 256), nn.ReLU(),
-             nn.Linear(256, 128), nn.ReLU(),
-             nn.Linear(128, output_dim)
-         )
-
-     def forward(self, x, feat):
-         feat_emb = self.feat_proj(feat)
-         stacked = torch.stack([x, feat_emb], dim=1)
-         encoded = self.transformer_encoder(stacked)
-         aggregated = encoded.mean(dim=1)
-         return self.regression_head(aggregated)
-
- # Load your saved model
- @st.cache_resource
- def load_regression_model():
-     model = TransformerRegressor()
-     state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
-     model.load_state_dict(state_dict)
-     model.eval()
-     return model
-
- model = load_regression_model()
-
- # Feature Functions
- descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
-                   Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
-                   Descriptors.NumHDonors, Descriptors.RingCount,
-                   Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
-                   Descriptors.NHOHCount]
-
- def fix_smiles(s):
-     try:
-         mol = Chem.MolFromSmiles(s.strip())
-         if mol:
-             return Chem.MolToSmiles(mol)
-     except:
-         return None
-     return None
 
- def compute_features(smiles):
      mol = Chem.MolFromSmiles(smiles)
-     if not mol:
-         return [0]*10 + [0]*2048
-     desc = [fn(mol) for fn in descriptor_fns]
-     fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
-     return desc + list(fp)
-
- def embed_smiles(smiles_list):
-     inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
-     outputs = chemberta(**inputs)
-     return outputs.last_hidden_state[:, 0, :]
-
- # Function to save prediction to MongoDB
- def save_to_db(smiles, predictions):
-     # Convert all prediction values to native Python float
-     predictions_clean = {k: float(v) for k, v in predictions.items()}
-
-     doc = {
-         "smiles": smiles,
-         "predictions": predictions_clean,
-         "timestamp": datetime.now()
-     }
-
-     db = get_database()  # Connect to MongoDB
-     collection = db["polymer_predictions"]
-     collection.insert_one(doc)
-
- # Prediction Page UI
- def show():
-     st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
-     st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
-
-     smiles_input = st.text_input("Enter SMILES Representation of Polymer")
-
-     if st.button("Predict"):
-         fixed = fix_smiles(smiles_input)
-         if not fixed:
-             st.error("Invalid SMILES string.")
          else:
-             features = compute_features(fixed)
-             features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
-             embedding = embed_smiles([fixed])
-
-             with torch.no_grad():
-                 pred = model(embedding, features_tensor)
-                 result = pred.numpy().flatten()
-
-             properties = [
-                 "Tensile Strength",
-                 "Ionization Energy",
-                 "Electron Affinity",
-                 "logP",
-                 "Refractive Index",
-                 "Molecular Weight"
-             ]
-
-             predictions = {}
-             st.success("Predicted Polymer Properties:")
-             for prop, val in zip(properties, result):
-                 st.write(f"**{prop}**: {val:.4f}")
-                 predictions[prop] = val
-
-             # Save the prediction to MongoDB
-             save_to_db(smiles_input, predictions)
-             st.success("Prediction saved successfully!")
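For context, the removed code imported get_database from a local db.py that is not part of this commit. A minimal sketch of what that helper might have looked like, assuming a pymongo client; the connection URI and database name here are illustrative, not taken from the repo:

    # db.py (hypothetical): the URI and database name are assumptions.
    from pymongo import MongoClient

    def get_database():
        # In practice the URI would come from an environment variable or Streamlit secrets.
        client = MongoClient("mongodb://localhost:27017/")
        return client["transpolymer"]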
 
  import streamlit as st
  import torch
+ import joblib
+ import pandas as pd
  import numpy as np
  from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from transformers import AutoTokenizer, AutoModel
+ import os
+
+ # Load ChemBERTa model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+ chemberta_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ chemberta_model.to(device)
+ chemberta_model.eval()
+
+ # Load models
+ model_dir = "saved_model"
+ model_paths = [os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith(".pkl") and "scaler" not in f]
+ models = [joblib.load(p) for p in model_paths]
+
+ # Load input and target scalers
+ input_scaler_path = os.path.join(model_dir, "scaler.pkl")
+ input_scaler = joblib.load(input_scaler_path) if os.path.exists(input_scaler_path) else None
+
+ target_scaler_path = os.path.join(model_dir, "target_scaler.pkl")
+ target_scaler = joblib.load(target_scaler_path) if os.path.exists(target_scaler_path) else None
+
+ # Properties
+ PROPERTY_NAMES = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
+
+ def smiles_to_fingerprint(smiles, radius=2, nBits=2048):
      mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+     return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits))
+
+ def smiles_to_chemberta_embedding(smiles):
+     inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
+     with torch.no_grad():
+         outputs = chemberta_model(**inputs)
+     return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
+
+ def create_features(smiles):
+     fp = smiles_to_fingerprint(smiles)
+     if fp is None:
+         return None
+     emb = smiles_to_chemberta_embedding(smiles)
+     return np.concatenate([fp, emb])
+
+ # Streamlit UI
+ st.title("TransPolymer Property Predictor")
+ user_input = st.text_input("Enter SMILES:")
+
+ if st.button("Predict"):
+     if not user_input.strip():
+         st.error("Please enter a valid SMILES string.")
+     else:
+         features = create_features(user_input)
+         if features is None:
+             st.error("Invalid SMILES format.")
          else:
+             if input_scaler:
+                 features = input_scaler.transform([features])
+             else:
+                 features = [features]
+
+             raw_preds = np.mean([model.predict(features) for model in models], axis=0).flatten()
+
+             if target_scaler:
+                 predictions = target_scaler.inverse_transform([raw_preds])[0]
+             else:
+                 predictions = raw_preds
+
+             result_df = pd.DataFrame([predictions], columns=PROPERTY_NAMES)
+             result_df.insert(0, "SMILES", user_input)
+
+             st.success("Predicted Properties:")
+             st.dataframe(result_df.style.format(precision=4))
+
+             # Optional: save to CSV
+             history_path = "prediction_history.csv"
+             if os.path.exists(history_path):
+                 existing = pd.read_csv(history_path)
+                 result_df = pd.concat([existing, result_df], ignore_index=True)
+             result_df.to_csv(history_path, index=False)
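The updated script expects a saved_model/ directory containing one or more joblib-serialized regressors (any *.pkl whose name does not contain "scaler") plus optional scaler.pkl and target_scaler.pkl. A minimal sketch of a training-side script that would produce matching artifacts; the RandomForestRegressor choice and the placeholder X/y arrays are assumptions, and each feature row must be the same 2048-bit Morgan fingerprint concatenated with the 768-dim ChemBERTa CLS embedding that create_features builds at inference time (2048 + 768 = 2816 columns):

    # Hypothetical training script; the model type and data are illustrative only.
    import os
    import joblib
    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.preprocessing import StandardScaler

    os.makedirs("saved_model", exist_ok=True)

    # X: (n_samples, 2816) fingerprint+embedding features; y: (n_samples, 6)
    # targets in the order of PROPERTY_NAMES. Random data stands in for a real dataset.
    X = np.random.rand(100, 2816)
    y = np.random.rand(100, 6)

    input_scaler = StandardScaler().fit(X)
    target_scaler = StandardScaler().fit(y)

    model = RandomForestRegressor(n_estimators=100).fit(
        input_scaler.transform(X), target_scaler.transform(y)
    )

    # Filenames must line up with what prediction.py globs for.
    joblib.dump(model, os.path.join("saved_model", "model_rf.pkl"))
    joblib.dump(input_scaler, os.path.join("saved_model", "scaler.pkl"))
    joblib.dump(target_scaler, os.path.join("saved_model", "target_scaler.pkl"))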