# Transpolymer2 / prediction.py
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np
# Load tokenizer and model for embeddings
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load the saved per-target scalers (used here only for inverse_transform).
# The filenames must match the artifacts saved during training exactly,
# including the "lonization" spelling in the ionization-energy scaler's file.
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")    # Tensile Strength (MPa)
scaler_ionization_energy = joblib.load("scaler_lonization_Energy_eV_.joblib")   # Ionization Energy (eV)
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")   # Electron Affinity (eV)
scaler_logp = joblib.load("scaler_LogP.joblib")                                  # LogP
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")         # Refractive Index
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")  # Molecular Weight (g/mol)
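
# For reference, each scaler was presumably produced on the training side along
# these lines (a sketch only; the actual scaler class and the target column
# names are assumptions, not taken from this file):
#
#     from sklearn.preprocessing import StandardScaler
#     scaler = StandardScaler().fit(y_train[["Tensile_strength_Mpa_"]])
#     joblib.dump(scaler, "scaler_Tensile_strength_Mpa_.joblib")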
# Descriptor function with the exact feature order used during training
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)
# Model class; the architecture must match the one used during training
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(),                      # (B, 2, d_model) -> (B, 2 * d_model)
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets)
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)       # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)                              # (B, 2, d_model)
        output = self.regressor(encoded)                                 # (B, num_targets)
        return output
# Load model
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()
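# eval() disables dropout in the encoder layers so repeated predictions for the
# same input are deterministic; map_location="cpu" lets the checkpoint load on
# machines without a GPU.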
# Main prediction function
def predict_properties(smiles):
    try:
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)      # (1, 10)

        # ChemBERTa embedding: first ([CLS]-equivalent) token of the last hidden state
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
            emb = outputs.last_hidden_state[:, 0, :]                     # (1, 768)

        # Forward pass through the regressor
        with torch.no_grad():
            preds = model(emb, descriptors_tensor)

        # Inverse-transform each target column with its own scaler
        preds_np = preds.numpy()
        preds_rescaled = np.concatenate([
            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
            scaler_logp.inverse_transform(preds_np[:, [3]]),
            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
        ], axis=1)

        # Round and label the outputs
        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
        return results
    except Exception as e:
        return {"error": str(e)}