Update prediction.py
prediction.py CHANGED (+32 -29)
@@ -6,26 +6,24 @@ from rdkit.Chem import Descriptors
 from transformers import AutoTokenizer, AutoModel
 import numpy as np
 
-# Load tokenizer and
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
-# Load
-target_keys = [
-    ...
-    "Molecular_Weight(g/mol)"
-]
-scalers = [joblib.load(f"scaler_{key.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('__', '_')}.joblib") for key in target_keys]
 
-# Descriptor function
 def compute_descriptors(smiles):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
-        raise ValueError("Invalid SMILES
     return np.array([
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),

@@ -39,7 +37,7 @@ def compute_descriptors(smiles):
         Descriptors.MolMR(mol)
     ], dtype=np.float32)
 
-#
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
         super().__init__()

@@ -57,9 +55,10 @@ class TransformerRegressor(nn.Module):
         desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
         stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
         encoded = self.transformer(stacked)  # (B, 2, d_model)
-        ...
 
-# Load
 model = TransformerRegressor()
 model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
 model.eval()

@@ -67,29 +66,33 @@ model.eval()
 # Main prediction function
 def predict_properties(smiles):
     try:
-        # Compute descriptors
         descriptors = compute_descriptors(smiles)
         descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)
 
-        # Get embedding
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
-        ...
 
-        #
         with torch.no_grad():
-            preds = model(
 
-        # Inverse transform
-        preds_np = preds.numpy()
-        preds_rescaled = [
-            ...
 
-        #
-        ...
-        results = dict(zip(
 
         return results
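The updated file drops the target_keys list and its filename-sanitizing comprehension in favor of six explicitly named scaler files. For context, here is a minimal sketch of how such per-target scalers could have been produced at training time; the StandardScaler choice, the DataFrame, and every column name except "Molecular_Weight(g/mol)" are assumptions, while the output filenames are copied from the loads below.

# Hypothetical training-side snippet (not part of this commit).
import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler  # assumed scaler class

targets_to_files = {
    "Tensile_strength(Mpa)": "scaler_Tensile_strength_Mpa_.joblib",      # assumed column name
    "Ionization_Energy(eV)": "scaler_lonization_Energy_eV_.joblib",      # filename spelling matches the load below
    "Electron_Affinity(eV)": "scaler_Electron_Affinity_eV_.joblib",      # assumed column name
    "LogP": "scaler_LogP.joblib",
    "Refractive_Index": "scaler_Refractive_Index.joblib",
    "Molecular_Weight(g/mol)": "scaler_Molecular_Weight_g_mol_.joblib",  # key seen in the removed code
}
df = pd.read_csv("polymer_targets.csv")  # hypothetical dataset path
for column, path in targets_to_files.items():
    scaler = StandardScaler().fit(df[[column]])  # fit on one column, shape (n, 1)
    joblib.dump(scaler, path)                    # saved where prediction.py expects it

The updated prediction.py, with added lines marked +: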
 from transformers import AutoTokenizer, AutoModel
 import numpy as np
 
+# Load tokenizer and model for embeddings
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
+# Load saved scalers (for inverse_transform)
+scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")  # Scaler for Tensile Strength
+scaler_ionization_energy = joblib.load("scaler_lonization_Energy_eV_.joblib")  # Scaler for Ionization Energy
+scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")  # Scaler for Electron Affinity
+scaler_logp = joblib.load("scaler_LogP.joblib")  # Scaler for LogP
+scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")  # Scaler for Refractive Index
+scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")  # Scaler for Molecular Weight
 
+# Descriptor function with exact order from training
 def compute_descriptors(smiles):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
+        raise ValueError("Invalid SMILES")
+
     return np.array([
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
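ChemBERTa's hidden size is 768, which is what input_dim=768 in the model class below assumes. A quick shape check for the embedding pipeline, runnable after the loads above (the SMILES string is arbitrary); the diff elides the middle descriptor lines of compute_descriptors at this point:

inputs = tokenizer("CCO", return_tensors="pt")  # dict with input_ids and attention_mask
with torch.no_grad():
    out = embedding_model(**inputs)
print(out.last_hidden_state.shape)              # torch.Size([1, seq_len, 768])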
         Descriptors.MolMR(mol)
     ], dtype=np.float32)
 
+# Define your model class exactly like in training
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
         super().__init__()
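The descriptor vector has to line up with descriptor_dim=10 in the signature just above, in the same order used during training. A small sanity check, assuming RDKit is installed (the example SMILES is arbitrary):

desc = compute_descriptors("CCO")  # ethanol, an arbitrary valid SMILES
assert desc.shape == (10,)         # must match descriptor_dim=10
assert desc.dtype == np.float32

try:
    compute_descriptors("not_a_smiles")  # malformed input
except ValueError as err:
    print(err)                           # "Invalid SMILES"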
         desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
         stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
         encoded = self.transformer(stacked)  # (B, 2, d_model)
+        output = self.regressor(encoded)
+        return output
 
+# Load model
 model = TransformerRegressor()
 model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
 model.eval()
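The layer definitions between __init__'s signature and the forward code fall outside these hunks. Purely as a hypothetical reconstruction, one body consistent with the attributes the forward pass uses: batch_first=True is inferred from the (B, 2, d_model) input, and the Flatten head is one way to make the output come out as (B, num_targets), which the per-column slicing in predict_properties below assumes.

import torch.nn as nn

class TransformerRegressorSketch(nn.Module):  # hypothetical, not the committed class
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768,
                 nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        # input_dim kept for signature parity; the ChemBERTa [CLS] vector
        # already has d_model width, so no extra projection is needed here.
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                           batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(start_dim=1),              # (B, 2, d_model) -> (B, 2*d_model)
            nn.Linear(2 * d_model, num_targets),  # -> (B, num_targets)
        )

If the real head instead maps each of the two tokens to num_targets, preds would be (B, 2, 6) and the slicing below would silently misbehave, so checking preds.shape once after loading is cheap insurance.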
 # Main prediction function
 def predict_properties(smiles):
     try:
         descriptors = compute_descriptors(smiles)
         descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)
 
+        # Get embedding
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
+        emb = outputs.last_hidden_state[:, 0, :]  # [CLS] token, shape (1, 768)
 
+        # Forward pass
         with torch.no_grad():
+            preds = model(emb, descriptors_tensor)
 
+        # Inverse transform predictions using respective scalers
+        preds_np = preds.numpy()
+        preds_rescaled = np.concatenate([
+            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
+            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
+            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
+            scaler_logp.inverse_transform(preds_np[:, [3]]),
+            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
+            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
+        ], axis=1)
 
+        # Round and format
+        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
+        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
 
         return results
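Example usage, assuming transformer_model.pt and the six .joblib scaler files sit next to prediction.py (the SMILES is illustrative):

if __name__ == "__main__":
    props = predict_properties("C=Cc1ccccc1")  # styrene, for example
    for name, value in props.items():
        print(f"{name}: {value}")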