import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load tokenizer and embedding model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load individual scalers (one per target, saved during training)
target_keys = [
    "Tensile_strength(Mpa)",
    "Ionization_Energy(eV)",
    "Electron_Affinity(eV)",
    "LogP",
    "Refractive_Index",
    "Molecular_Weight(g/mol)",
]

def _sanitize_key(key):
    """Reproduce the filename sanitization used when the scalers were saved."""
    for old, new in [("/", "_"), (" ", "_"), ("(", ""), (")", ""), ("[", ""), ("]", ""), ("__", "_")]:
        key = key.replace(old, new)
    return key

scalers = [joblib.load(f"scaler_{_sanitize_key(key)}.joblib") for key in target_keys]
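# For reference, _sanitize_key maps the keys above to the saved filenames, e.g.
#   "Tensile_strength(Mpa)"   -> scaler_Tensile_strengthMpa.joblib
#   "Molecular_Weight(g/mol)" -> scaler_Molecular_Weightg_mol.joblib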
# Descriptor function (must match training order)
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ], dtype=np.float32)
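# Quick sanity check (sketch; "CCO"/ethanol is an arbitrary example input,
# not part of the original app):
#   vec = compute_descriptors("CCO")
#   assert vec.shape == (10,) and vec.dtype == np.float32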
# Model class must match training (input_dim is kept for signature
# compatibility with the training script; it is not used directly)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(),                 # (B, 2, d_model) -> (B, 2 * d_model)
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets),
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)       # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)                              # (B, 2, d_model)
        return self.regressor(encoded)                                   # (B, num_targets)
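# Shape smoke test (sketch, using the default constructor arguments):
#   m = TransformerRegressor()
#   out = m(torch.randn(2, 768), torch.randn(2, 10))  # (batch, 768) embedding + (batch, 10) descriptors
#   assert out.shape == (2, 6)  # (batch, num_targets)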
# Load trained model (CPU is sufficient for inference)
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()

# Main prediction function
def predict_properties(smiles):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.from_numpy(descriptors).unsqueeze(0)  # (1, 10)

        # Get molecule embedding from ChemBERTa ([CLS] token of the last layer);
        # truncation guards against SMILES longer than the model's context window
        inputs = tokenizer(smiles, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = embedding_model(**inputs)
            embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

            # Predict all six targets in one forward pass
            preds = model(embedding, descriptors_tensor)

        # Inverse-transform each prediction with its own scaler
        preds_np = preds.numpy().flatten()
        preds_rescaled = [
            scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
        ]

        # Prepare results
        readable_keys = [
            "Tensile Strength", "Ionization Energy", "Electron Affinity",
            "logP", "Refractive Index", "Molecular Weight",
        ]
        return dict(zip(readable_keys, np.round(preds_rescaled, 4)))
    except Exception as e:
        return {"error": str(e)}