# Transpolymer2 / prediction.py
import numpy as np
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel

# Load tokenizer and embedding model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load individual scalers (one per target property)
target_keys = [
    "Tensile_strength(Mpa)",
    "Ionization_Energy(eV)",
    "Electron_Affinity(eV)",
    "LogP",
    "Refractive_Index",
    "Molecular_Weight(g/mol)",
]


def sanitize_key(key):
    """Apply the same character replacements used for the scaler filenames."""
    for old, new in (("/", "_"), (" ", "_"), ("(", ""), (")", ""), ("[", ""), ("]", ""), ("__", "_")):
        key = key.replace(old, new)
    return key


scalers = [joblib.load(f"scaler_{sanitize_key(key)}.joblib") for key in target_keys]
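# e.g. "Tensile_strength(Mpa)" -> "scaler_Tensile_strengthMpa.joblib"
# (this assumes the scalers were saved with the same key sanitisation at training time)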


# Descriptor function (must match the descriptor order used during training)
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ], dtype=np.float32)
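# For example, compute_descriptors("CCO") (ethanol) returns a 10-element float32
# vector of RDKit descriptors in the order listed above.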


# Model class (must match the architecture used during training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets),
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)  # (B, 2, d_model)
        return self.regressor(encoded)  # (B, num_targets)

# Load trained model
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()


# Main prediction function: maps a SMILES string to a dict of predicted properties
def predict_properties(smiles):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)

        # Get the [CLS] embedding from ChemBERTa
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
            embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Predict (outputs are still in scaled units here)
        with torch.no_grad():
            preds = model(embedding, descriptors_tensor)

        # Inverse-transform each prediction with its corresponding scaler
        preds_np = preds.numpy().flatten()
        preds_rescaled = [
            scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
        ]

        # Prepare results
        readable_keys = [
            "Tensile Strength", "Ionization Energy", "Electron Affinity",
            "logP", "Refractive Index", "Molecular Weight",
        ]
        results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
        return results
    except Exception as e:
        return {"error": str(e)}