import numpy as np
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel

# Load tokenizer and embedding model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load individual per-target scalers (one .joblib file per property, saved at training time)
target_keys = [
    "Tensile_strength(Mpa)",
    "Ionization_Energy(eV)",
    "Electron_Affinity(eV)",
    "LogP",
    "Refractive_Index",
    "Molecular_Weight(g/mol)",
]
scalers = [
    joblib.load(
        "scaler_"
        + key.replace("/", "_").replace(" ", "_").replace("(", "").replace(")", "")
             .replace("[", "").replace("]", "").replace("__", "_")
        + ".joblib"
    )
    for key in target_keys
]


# Descriptor function (must match training order)
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    return np.array(
        [
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Descriptors.TPSA(mol),
            Descriptors.NumRotatableBonds(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.FractionCSP3(mol),
            Descriptors.HeavyAtomCount(mol),
            Descriptors.RingCount(mol),
            Descriptors.MolMR(mol),
        ],
        dtype=np.float32,
    )


# Model class must match training
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4,
                 num_layers=2, num_targets=6):
        super().__init__()
        # Project the RDKit descriptors into the same space as the ChemBERTa embedding
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Flatten the two encoded tokens (embedding + descriptors) before regression
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets),
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)       # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)                              # (B, 2, d_model)
        return self.regressor(encoded)                                   # (B, num_targets)


# Load trained model
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()


# Main prediction function
def predict_properties(smiles):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)

        # Get embedding from ChemBERTa (first token of the last hidden state)
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Predict scaled targets
        with torch.no_grad():
            preds = model(embedding, descriptors_tensor)

        # Inverse transform each prediction with its own scaler
        preds_np = preds.numpy().flatten()
        preds_rescaled = [
            scalers[i].inverse_transform([[preds_np[i]]])[0][0]
            for i in range(len(scalers))
        ]

        # Prepare results
        readable_keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity",
                         "logP", "Refractive Index", "Molecular Weight"]
        results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
        return results

    except Exception as e:
        return {"error": str(e)}
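

# --- Example usage (illustrative sketch) ---
# A minimal sanity check, assuming "transformer_model.pt" and the per-target
# scaler .joblib files produced at training time are present in the working
# directory. The SMILES string below (aspirin) is an arbitrary example chosen
# here for illustration, not taken from the original training data.
if __name__ == "__main__":
    example_smiles = "CC(=O)Oc1ccccc1C(=O)O"
    predictions = predict_properties(example_smiles)
    if "error" in predictions:
        print(f"Prediction failed: {predictions['error']}")
    else:
        for name, value in predictions.items():
            print(f"{name}: {value}")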