Update prediction.py
prediction.py  +89 -80  CHANGED
@@ -1,88 +1,97 @@
-import streamlit as st
 import torch
 import joblib
-import pandas as pd
-import numpy as np
 from rdkit import Chem
-from rdkit.Chem import
 from transformers import AutoTokenizer, AutoModel
-import

-# Load
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
-…
-target_scaler_path = os.path.join(model_dir, "target_scaler.pkl")
-target_scaler = joblib.load(target_scaler_path) if os.path.exists(target_scaler_path) else None
-
-# Properties
-PROPERTY_NAMES = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
-
-def smiles_to_fingerprint(smiles, radius=2, nBits=2048):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
-        …
-    return np.array(
-…
 import torch
+import torch.nn as nn
 import joblib
 from rdkit import Chem
+from rdkit.Chem import Descriptors
 from transformers import AutoTokenizer, AutoModel
+import numpy as np

+# Load tokenizer and embedding model
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+
+# Load individual scalers
+target_keys = [
+    "Tensile_strength(Mpa)",
+    "Ionization_Energy(eV)",
+    "Electron_Affinity(eV)",
+    "LogP",
+    "Refractive_Index",
+    "Molecular_Weight(g/mol)"
+]
+scalers = [joblib.load(f"scaler_{key.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('__', '_')}.joblib") for key in target_keys]
+
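Note (not part of the diff): the scaler filenames above are produced purely by the chain of str.replace calls, so they can be checked offline. A minimal sketch of what the six keys resolve to:

    for key in target_keys:
        name = (key.replace('/', '_').replace(' ', '_').replace('(', '')
                   .replace(')', '').replace('[', '').replace(']', '').replace('__', '_'))
        print(f"scaler_{name}.joblib")
    # e.g. "Tensile_strength(Mpa)"   -> scaler_Tensile_strengthMpa.joblib
    #      "Molecular_Weight(g/mol)" -> scaler_Molecular_Weightg_mol.joblib

These are the per-target scaler files the Space repository is expected to contain.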
+# Descriptor function (must match training order)
+def compute_descriptors(smiles):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
+        raise ValueError("Invalid SMILES string.")
+    return np.array([
+        Descriptors.MolWt(mol),
+        Descriptors.MolLogP(mol),
+        Descriptors.TPSA(mol),
+        Descriptors.NumRotatableBonds(mol),
+        Descriptors.NumHDonors(mol),
+        Descriptors.NumHAcceptors(mol),
+        Descriptors.FractionCSP3(mol),
+        Descriptors.HeavyAtomCount(mol),
+        Descriptors.RingCount(mol),
+        Descriptors.MolMR(mol)
+    ], dtype=np.float32)
+
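Note (not part of the diff): compute_descriptors returns a fixed-order vector of ten RDKit descriptors, and that order has to match whatever was used at training time. A quick sanity-check sketch, using ethanol ("CCO") as an arbitrary example input:

    desc = compute_descriptors("CCO")
    print(desc.shape)  # (10,)
    print(desc[0])     # MolWt first, ~46.07 g/mol for ethanol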
+# Model class must match training
+class TransformerRegressor(nn.Module):
+    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
+        super().__init__()
+        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
+        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.regressor = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(2 * d_model, 256),
+            nn.ReLU(),
+            nn.Linear(256, num_targets)
+        )
+
+    def forward(self, embedding, descriptors):
+        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
+        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
+        encoded = self.transformer(stacked)  # (B, 2, d_model)
+        return self.regressor(encoded)
+
+# Load trained model
+model = TransformerRegressor()
+model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
+model.eval()
+
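Note (not part of the diff): the model treats the ChemBERTa [CLS] embedding and the projected descriptor vector as a two-token sequence, so nn.Flatten feeds 2 * d_model = 1536 features into the regression head. An illustrative shape check with dummy inputs (random weights, no trained checkpoint needed):

    dummy_emb = torch.zeros(1, 768)   # stand-in for the ChemBERTa [CLS] embedding
    dummy_desc = torch.zeros(1, 10)   # stand-in for the RDKit descriptor vector
    with torch.no_grad():
        out = TransformerRegressor()(dummy_emb, dummy_desc)
    print(out.shape)  # torch.Size([1, 6]) -- one value per target property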
+# Main prediction function
+def predict_properties(smiles):
+    try:
+        # Compute descriptors
+        descriptors = compute_descriptors(smiles)
+        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)
+
+        # Get embedding from ChemBERTa
+        inputs = tokenizer(smiles, return_tensors="pt")
+        with torch.no_grad():
+            outputs = embedding_model(**inputs)
+            embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
+
+        # Predict
+        with torch.no_grad():
+            preds = model(embedding, descriptors_tensor)
+
+        # Inverse transform each prediction
+        preds_np = preds.numpy().flatten()
+        preds_rescaled = [
+            scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
+        ]
+
+        # Prepare results
+        readable_keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
+        results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
+
+        return results
+
+    except Exception as e:
+        return {"error": str(e)}
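Note (not part of the diff): predict_properties returns either a dict of the six rounded property values or {"error": ...} if anything fails. A hypothetical caller, assuming this module is importable as prediction and that transformer_model.pt plus the scaler files are present in the Space (the previous version imported streamlit, so the real entry point is presumably a UI script):

    from prediction import predict_properties

    result = predict_properties("CC(=O)OC1=CC=CC=C1C(=O)O")  # aspirin, as an example SMILES
    if "error" in result:
        print("Prediction failed:", result["error"])
    else:
        for prop, value in result.items():
            print(f"{prop}: {value}")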