transpolymer committed on
Commit 5e9e549 · verified · 1 Parent(s): c621eb3

Update prediction.py

Files changed (1)
  1. prediction.py +89 -80
prediction.py CHANGED
@@ -1,88 +1,97 @@
- import streamlit as st
  import torch
+ import torch.nn as nn
  import joblib
- import pandas as pd
- import numpy as np
  from rdkit import Chem
- from rdkit.Chem import AllChem
+ from rdkit.Chem import Descriptors
  from transformers import AutoTokenizer, AutoModel
- import os
+ import numpy as np

- # Load ChemBERTa model and tokenizer
+ # Load tokenizer and embedding model
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
- chemberta_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- chemberta_model.to(device)
- chemberta_model.eval()
-
- # Load models
- model_dir = "saved_model"
- model_paths = [os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith(".pkl") and "scaler" not in f]
- models = [joblib.load(p) for p in model_paths]
-
- # Load input and target scalers
- input_scaler_path = os.path.join(model_dir, "scaler.pkl")
- input_scaler = joblib.load(input_scaler_path) if os.path.exists(input_scaler_path) else None
-
- target_scaler_path = os.path.join(model_dir, "target_scaler.pkl")
- target_scaler = joblib.load(target_scaler_path) if os.path.exists(target_scaler_path) else None
-
- # Properties
- PROPERTY_NAMES = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
-
- def smiles_to_fingerprint(smiles, radius=2, nBits=2048):
+ embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+
+ # Load individual scalers
+ target_keys = [
+     "Tensile_strength(Mpa)",
+     "Ionization_Energy(eV)",
+     "Electron_Affinity(eV)",
+     "LogP",
+     "Refractive_Index",
+     "Molecular_Weight(g/mol)"
+ ]
+ scalers = [joblib.load(f"scaler_{key.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('__', '_')}.joblib") for key in target_keys]
+
+ # Descriptor function (must match training order)
+ def compute_descriptors(smiles):
      mol = Chem.MolFromSmiles(smiles)
      if mol is None:
-         return None
-     return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits))
-
- def smiles_to_chemberta_embedding(smiles):
-     inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
-     with torch.no_grad():
-         outputs = chemberta_model(**inputs)
-     return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
-
- def create_features(smiles):
-     fp = smiles_to_fingerprint(smiles)
-     if fp is None:
-         return None
-     emb = smiles_to_chemberta_embedding(smiles)
-     return np.concatenate([fp, emb])
-
- # Streamlit UI
- st.title("TransPolymer Property Predictor")
- user_input = st.text_input("Enter SMILES:")
-
- if st.button("Predict"):
-     if not user_input.strip():
-         st.error("Please enter a valid SMILES string.")
-     else:
-         features = create_features(user_input)
-         if features is None:
-             st.error("Invalid SMILES format.")
-         else:
-             if input_scaler:
-                 features = input_scaler.transform([features])
-             else:
-                 features = [features]
-
-             raw_preds = np.mean([model.predict(features) for model in models], axis=0).flatten()
-
-             if target_scaler:
-                 predictions = target_scaler.inverse_transform([raw_preds])[0]
-             else:
-                 predictions = raw_preds
-
-             result_df = pd.DataFrame([predictions], columns=PROPERTY_NAMES)
-             result_df.insert(0, "SMILES", user_input)
-
-             st.success("Predicted Properties:")
-             st.dataframe(result_df.style.format(precision=4))
-
-             # Optional: save to CSV
-             history_path = "prediction_history.csv"
-             if os.path.exists(history_path):
-                 existing = pd.read_csv(history_path)
-                 result_df = pd.concat([existing, result_df], ignore_index=True)
-             result_df.to_csv(history_path, index=False)
+         raise ValueError("Invalid SMILES string.")
+     return np.array([
+         Descriptors.MolWt(mol),
+         Descriptors.MolLogP(mol),
+         Descriptors.TPSA(mol),
+         Descriptors.NumRotatableBonds(mol),
+         Descriptors.NumHDonors(mol),
+         Descriptors.NumHAcceptors(mol),
+         Descriptors.FractionCSP3(mol),
+         Descriptors.HeavyAtomCount(mol),
+         Descriptors.RingCount(mol),
+         Descriptors.MolMR(mol)
+     ], dtype=np.float32)
+
+ # Model class must match training
+ class TransformerRegressor(nn.Module):
+     def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
+         super().__init__()
+         self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
+         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.regressor = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(2 * d_model, 256),
+             nn.ReLU(),
+             nn.Linear(256, num_targets)
+         )
+
+     def forward(self, embedding, descriptors):
+         desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
+         stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
+         encoded = self.transformer(stacked)  # (B, 2, d_model)
+         return self.regressor(encoded)
+
+ # Load trained model
+ model = TransformerRegressor()
+ model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
+ model.eval()
+
+ # Main prediction function
+ def predict_properties(smiles):
+     try:
+         # Compute descriptors
+         descriptors = compute_descriptors(smiles)
+         descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)
+
+         # Get embedding from ChemBERTa
+         inputs = tokenizer(smiles, return_tensors="pt")
+         with torch.no_grad():
+             outputs = embedding_model(**inputs)
+             embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
+
+         # Predict
+         with torch.no_grad():
+             preds = model(embedding, descriptors_tensor)
+
+         # Inverse transform each prediction
+         preds_np = preds.numpy().flatten()
+         preds_rescaled = [
+             scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
+         ]
+
+         # Prepare results
+         readable_keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
+         results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
+
+         return results
+
+     except Exception as e:
+         return {"error": str(e)}
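
Usage note (not part of this commit): a minimal smoke test for the new module. It assumes prediction.py sits next to transformer_model.pt and the six scaler_*.joblib files produced during training, and uses ethanol ("CCO") as a stand-in input; importing prediction runs the module-level model and scaler loading, so those artifacts must exist.

    # Hypothetical usage sketch; artifact files are assumed to be present.
    from prediction import predict_properties

    result = predict_properties("CCO")  # ethanol
    if "error" in result:  # predict_properties traps all exceptions itself
        print("Prediction failed:", result["error"])
    else:
        for name, value in result.items():
            print(f"{name}: {value:.4f}")

The filename sanitizer in the scaler-loading comprehension strips parentheses and brackets, maps "/" and spaces to "_", and collapses double underscores, so "Tensile_strength(Mpa)" resolves to scaler_Tensile_strengthMpa.joblib and "Molecular_Weight(g/mol)" to scaler_Molecular_Weightg_mol.joblib. The training script presumably dumped one fitted scaler per target under the same scheme; below is a sketch of that assumed counterpart (the scaler type and variable names are guesses, not taken from this repo).

    # Assumed training-side counterpart: one scaler per target column,
    # saved under the sanitized filenames prediction.py expects.
    import joblib
    import numpy as np
    from sklearn.preprocessing import StandardScaler

    target_keys = [
        "Tensile_strength(Mpa)", "Ionization_Energy(eV)", "Electron_Affinity(eV)",
        "LogP", "Refractive_Index", "Molecular_Weight(g/mol)"
    ]

    def sanitize(key):
        # Mirrors the replace chain in prediction.py exactly, in the same order.
        for old, new in [("/", "_"), (" ", "_"), ("(", ""), (")", ""),
                         ("[", ""), ("]", ""), ("__", "_")]:
            key = key.replace(old, new)
        return key

    y = np.random.rand(100, len(target_keys))  # placeholder targets for illustration
    for i, key in enumerate(target_keys):
        scaler = StandardScaler().fit(y[:, [i]])
        joblib.dump(scaler, f"scaler_{sanitize(key)}.joblib")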