# Transpolymer2 / prediction.py
import torch
import torch.nn as nn
import joblib
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
# Load ChemBERTa tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load saved scalers for inverse transformations
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib"),
}
# Descriptor calculation
def compute_descriptors(smiles: str):
    """Compute a fixed set of 10 RDKit descriptors for a SMILES string as a float32 vector."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ]
    return np.array(descriptors, dtype=np.float32)
# Transformer regression model definition (must match training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        x = self.feat_proj(x)            # project concatenated features to embedding_dim
        x = self.transformer_encoder(x)  # contextualize with self-attention
        x = x.mean(dim=1)                # mean-pool over the sequence dimension
        return self.regression_head(x)
# Model hyperparameters (must match training)
embedding_dim = 768
descriptor_dim = 1290  # must match the descriptor feature length used at training time
input_dim = embedding_dim + descriptor_dim # 768 + 1290 = 2058
ff_dim = 1024
num_layers = 2
output_dim = 6
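
# Illustrative consistency check (an assumption, not part of the original script):
# the combined feature width fed to the model must equal input_dim, i.e. the
# ChemBERTa CLS embedding (768) plus the number of descriptor features used at
# training time. Uncomment once the descriptor pipeline matches the trained model:
# assert embedding_dim + len(compute_descriptors("CCO")) == input_dim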
# Load trained model
device = torch.device("cpu")
model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
model.eval()
# Prediction function
def predict_properties(smiles: str):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)

        # Get ChemBERTa embedding (CLS token)
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Combine features and add a sequence dimension:
        # shape (1, 1, 768 + number of descriptors); this must equal (1, 1, input_dim)
        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)

        # Forward pass
        with torch.no_grad():
            preds = model(combined)
        preds_np = preds.numpy()

        # Inverse transform each property back to its original scale
        keys = list(scalers.keys())
        preds_rescaled = np.concatenate([
            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
            for i in range(output_dim)
        ], axis=1)

        results = {key: round(float(val), 4) for key, val in zip(keys, preds_rescaled.flatten())}
        return results
    except Exception as e:
        return {"error": str(e)}
# Show function to print the results
def show(smiles: str):
    result = predict_properties(smiles)
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print("Predicted Properties for SMILES:", smiles)
        for key, value in result.items():
            print(f"{key}: {value}")