import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load tokenizer and model for embeddings
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load saved scalers (for inverse_transform)
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
scaler_logp = joblib.load("scaler_LogP.joblib")
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
# Descriptor function with exact order from training
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)
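
# Illustrative check (not from the original script): a simple SMILES such as "CCO"
# (ethanol) returns a length-10 float32 vector in the order above, while an invalid
# string raises ValueError, e.g. compute_descriptors("CCO").shape -> (10,)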
# Define your model class exactly like in training
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        x = self.feat_proj(x)            # project embeddings to hidden_dim
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)                # pool over dim 1 before the regression head
        return self.regression_head(x)
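
# Note (added for clarity): the forward pass above mean-pools over dim 1 and feeds a
# hidden_dim-sized vector to the regression head, so the model expects a 3D input;
# a single (1, 768) embedding therefore needs an extra length-1 dimension, which is
# added in predict_properties below.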
# Set model hyperparameters (must match training config)
input_dim = 768   # ChemBERTa embedding size
hidden_dim = 256
num_layers = 2
output_dim = 6    # Number of properties predicted

# Load model
model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()
# Main prediction function
def predict_properties(smiles):
    try:
        # Compute RDKit descriptors; this also validates the SMILES. Note that only
        # the 768-dim ChemBERTa embedding is passed to the model (input_dim above).
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)

        # Get embedding
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        emb = outputs.last_hidden_state[:, 0, :]  # CLS token output (1, 768)
        # Add a length-1 sequence dimension so the encoder and mean(dim=1)
        # pooling in the forward pass receive a 3D tensor: (1, 1, 768)
        emb = emb.unsqueeze(1)

        # Forward pass
        with torch.no_grad():
            preds = model(emb)
        preds_np = preds.numpy()

        # Undo the per-target scaling applied during training
        preds_rescaled = np.concatenate([
            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
            scaler_logp.inverse_transform(preds_np[:, [3]]),
            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
        ], axis=1)

        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity",
                "logP", "Refractive Index", "Molecular Weight"]
        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
        return results
    except Exception as e:
        return {"error": str(e)}
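
# Illustrative usage (assumed entry point, not part of the original Space code):
# run the prediction end-to-end on an example molecule and print the results.
if __name__ == "__main__":
    example_smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, an arbitrary test molecule
    print(predict_properties(example_smiles))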