import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np
# Load tokenizer and model for embeddings
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load saved scalers (for inverse_transform)
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
scaler_logp = joblib.load("scaler_LogP.joblib")
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
# Descriptor function with exact order from training
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)
# Define the model class exactly as in training
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )
    def forward(self, x):
        # x is expected to be 3D, e.g. (1, 1, input_dim) for a single embedding,
        # so the mean over dim=1 collapses that dimension before the regression head
        x = self.feat_proj(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)
        return self.regression_head(x)
# Set model hyperparameters (must match training config)
input_dim = 768 # ChemBERTa embedding size
hidden_dim = 256
num_layers = 2
output_dim = 6 # Number of properties predicted
# Load model
model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()
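# Optional sanity check (not part of the original pipeline): run a dummy
# (batch=1, seq_len=1, 768) tensor through the loaded model to confirm it
# returns one value per property. The (1, 1, 768) layout is an assumption
# about the input shape used during training.
with torch.no_grad():
    _dummy = torch.zeros(1, 1, input_dim)
    assert model(_dummy).shape == (1, output_dim)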
# Main prediction function
def predict_properties(smiles):
    try:
        # Compute RDKit descriptors; this also validates the SMILES string.
        # Note: the descriptors are not fed to the model below (input_dim is the
        # 768-dim ChemBERTa embedding alone), so they act only as a validity check here.
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)
        # Get ChemBERTa embedding for the SMILES string
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        emb = outputs.last_hidden_state[:, 0, :]  # CLS token output, shape (1, 768)
        # Forward pass; add a sequence dimension so the encoder sees (1, 1, 768),
        # which is assumed to match the input layout used during training
        with torch.no_grad():
            preds = model(emb.unsqueeze(1))
        preds_np = preds.numpy()
        # Undo the per-target scaling applied during training
        preds_rescaled = np.concatenate([
            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
            scaler_logp.inverse_transform(preds_np[:, [3]]),
            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
        ], axis=1)
        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity",
                "logP", "Refractive Index", "Molecular Weight"]
        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
        return results
    except Exception as e:
        return {"error": str(e)}