Update prediction.py
prediction.py  CHANGED  +41 -39
@@ -1,30 +1,32 @@
 import torch
 import torch.nn as nn
 import joblib
+import numpy as np
 from rdkit import Chem
 from rdkit.Chem import Descriptors
 from transformers import AutoTokenizer, AutoModel
-import numpy as np
 
-# Load tokenizer and model
+# Load ChemBERTa tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
-# Load saved scalers
+# Load saved scalers for inverse transformations
+scalers = {
+    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
+    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
+    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
+    "logP": joblib.load("scaler_LogP.joblib"),
+    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
+    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
+}
 
-# Descriptor
-def compute_descriptors(smiles):
+# Descriptor calculation
+def compute_descriptors(smiles: str):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
-        raise ValueError("Invalid SMILES")
+        raise ValueError("Invalid SMILES string.")
+
+    descriptors = [
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
         Descriptors.TPSA(mol),
@@ -35,9 +37,10 @@ def compute_descriptors(smiles):
         Descriptors.HeavyAtomCount(mol),
         Descriptors.RingCount(mol),
         Descriptors.MolMR(mol)
-    ]
+    ]
+    return np.array(descriptors, dtype=np.float32)
 
-#
+# Transformer regression model definition
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
         super().__init__()
@@ -55,49 +58,48 @@ class TransformerRegressor(nn.Module):
     def forward(self, x):
         x = self.feat_proj(x)
         x = self.transformer_encoder(x)
-        x = x.mean(dim=1)
+        x = x.mean(dim=1)  # Global average pooling
         return self.regression_head(x)
 
-#
-input_dim = 768
+# Model hyperparameters (must match training)
+input_dim = 768
 hidden_dim = 256
 num_layers = 2
-output_dim = 6
+output_dim = 6
 
-# Load model
+# Load trained model
+device = torch.device("cpu")
 model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
-model.load_state_dict(torch.load("transformer_model.pt", map_location=
+model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
 model.eval()
 
-#
-def predict_properties(smiles):
+# Prediction function
+def predict_properties(smiles: str):
     try:
+        # Validate SMILES and compute descriptors
+        _ = compute_descriptors(smiles)
 
-        #
+        # ChemBERTa embedding (CLS token)
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
+        embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (1, 768)
 
-        # Forward pass
+        # Forward pass through model
        with torch.no_grad():
-            preds = model(
+            preds = model(embedding)
 
         preds_np = preds.numpy()
+
+        # Inverse transform each property
+        keys = list(scalers.keys())
         preds_rescaled = np.concatenate([
-            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
-            scaler_logp.inverse_transform(preds_np[:, [3]]),
-            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
-            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
+            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
+            for i in range(output_dim)
         ], axis=1)
 
-        results =
+        # Create dictionary of results
+        results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
         return results
 
     except Exception as e:
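For reference, a minimal usage sketch of the updated predict_properties function. This is not part of the commit; it assumes transformer_model.pt and the scaler .joblib files shown above are present in the working directory, that the call succeeds, and the aspirin SMILES is only an illustrative input.

# Hypothetical usage sketch -- not part of this commit. Assumes the model
# weights and scaler files referenced in the diff exist locally.
if __name__ == "__main__":
    example_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"  # aspirin, illustrative input
    predictions = predict_properties(example_smiles)
    for name, value in predictions.items():
        print(f"{name}: {value}")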