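# Transpolymer2 / prediction.py
# NOTE: the lines below appear to be residue copied from a web file viewer
# (author, commit message, commit id, file size). They are kept as comments
# so that this module is valid Python.
# transpolymer's picture
# Update prediction.py
# eea9e94 verified
# raw
# history blame
# 3.94 kB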
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np
# Pretrained ChemBERTa checkpoint providing the tokenizer and the encoder
# whose hidden states serve as the molecular embedding.
_CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_CHECKPOINT)
embedding_model = AutoModel.from_pretrained(_CHEMBERTA_CHECKPOINT)

# One fitted scaler per predicted property. Only inverse_transform is used
# here, to map the model's scaled outputs back to the original units.
# NOTE(review): joblib.load unpickles its input; only load trusted files.
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
scaler_logp = joblib.load("scaler_LogP.joblib")
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
# Descriptor function with exact order from training
def compute_descriptors(smiles):
    """Return the RDKit descriptor vector for a SMILES string.

    The order of the values is fixed (it must match the feature order used
    in training) and must not be changed:
    MolWt, MolLogP, TPSA, NumRotatableBonds, NumHDonors, NumHAcceptors,
    FractionCSP3, HeavyAtomCount, RingCount, MolMR.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        1-D ``np.float32`` array of length 10.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` or the parsed molecule
            contains no atoms (e.g. an empty string).
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None on a parse failure, but an empty string
    # parses to a zero-atom molecule instead of None. Reject that too, so
    # blank input is not silently treated as a valid molecule.
    if mol is None or mol.GetNumAtoms() == 0:
        raise ValueError("Invalid SMILES")
    return np.array(
        [
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Descriptors.TPSA(mol),
            Descriptors.NumRotatableBonds(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.FractionCSP3(mol),
            Descriptors.HeavyAtomCount(mol),
            Descriptors.RingCount(mol),
            Descriptors.MolMR(mol),
        ],
        dtype=np.float32,
    )
# Define your model class exactly like in training
class TransformerRegressor(nn.Module):
    """Linear projection -> Transformer encoder -> mean pooling -> MLP head.

    The submodule attribute names (``feat_proj``, ``transformer_encoder``,
    ``regression_head``) and the layer order inside ``regression_head`` make
    up the keys of the ``state_dict``. They must stay unchanged for
    ``load_state_dict`` to accept a trained checkpoint.

    Args:
        input_dim: size of the last axis of the input features.
        hidden_dim: width used inside the encoder. It must be divisible by
            the 8 attention heads.
        num_layers: number of stacked encoder layers.
        output_dim: number of regression targets.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, hidden_dim)
        # No batch_first argument is given, so PyTorch's default applies.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8),
            num_layers=num_layers,
        )
        # Sequential indices 0/2/4 hold the Linear layers. The ReLUs at 1/3
        # carry no parameters, so they do not appear in the state_dict.
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        """Project, encode, pool over dim 1, and regress ``output_dim`` values.

        Expects a 3-D input whose last axis is ``input_dim``. The encoder
        uses the default ``batch_first=False`` layout (sequence axis first),
        so ``mean(dim=1)`` pools over dim 1 of that layout.
        NOTE(review): confirm this layout matches how the model was trained.

        A 2-D input is treated by the encoder as one unbatched sequence.
        ``mean(dim=1)`` would then pool away the feature axis and the head
        would fail on a shape mismatch.
        """
        encoded = self.transformer_encoder(self.feat_proj(x))
        pooled = encoded.mean(dim=1)
        return self.regression_head(pooled)
# Architecture hyperparameters. These must match the training configuration,
# otherwise load_state_dict will reject the checkpoint's tensor shapes.
input_dim = 768  # ChemBERTa embedding size
hidden_dim = 256
num_layers = 2
output_dim = 6  # Number of properties predicted

# Build the regressor, restore the trained weights onto the CPU, and switch
# the model to evaluation mode.
model = TransformerRegressor(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    output_dim=output_dim,
)
model.load_state_dict(torch.load("transformer_model.pt", map_location="cpu"))
model.eval()
# Main prediction function
def predict_properties(smiles):
    """Predict six polymer properties for a SMILES string.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        On success, a dict mapping each property name ("Tensile Strength",
        "Ionization Energy", "Electron Affinity", "logP", "Refractive Index",
        "Molecular Weight") to its inverse-scaled prediction. Each value is a
        Python float rounded to 4 decimals.
        On any failure, ``{"error": <message>}`` is returned instead; this
        function does not raise.
    """
    try:
        # Output column i of the model is inverse-scaled with the i-th
        # scaler. Keeping each name next to its scaler prevents the two
        # lists from drifting apart.
        targets = (
            ("Tensile Strength", scaler_tensile_strength),
            ("Ionization Energy", scaler_ionization_energy),
            ("Electron Affinity", scaler_electron_affinity),
            ("logP", scaler_logp),
            ("Refractive Index", scaler_refractive_index),
            ("Molecular Weight", scaler_molecular_weight),
        )

        # Called only for validation: it raises ValueError on an unusable
        # SMILES. The descriptor values are not model inputs; the regressor
        # consumes only the 768-dim embedding.
        compute_descriptors(smiles)

        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
            # First-token ("CLS") hidden state, shape (1, 768).
            emb = outputs.last_hidden_state[:, 0, :]
            # A 2-D (1, 768) tensor would be treated by the encoder as an
            # unbatched sequence. The model's mean(dim=1) would then collapse
            # the feature axis and the regression head would fail on a shape
            # mismatch. Add an axis so the input is (1, 1, 768): one token,
            # one sample.
            preds = model(emb.unsqueeze(1))

        preds_np = preds.numpy()
        rescaled = np.concatenate(
            [
                scaler.inverse_transform(preds_np[:, [i]])
                for i, (_, scaler) in enumerate(targets)
            ],
            axis=1,
        )
        # Convert to plain Python floats so the result is JSON-serializable.
        return {
            name: round(float(value), 4)
            for (name, _), value in zip(targets, rescaled.flatten())
        }
    except Exception as e:
        # Top-level boundary: report any failure to the caller as data.
        return {"error": str(e)}