import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np
# Load tokenizer and embedding model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load individual scalers
target_keys = [
"Tensile_strength(Mpa)",
"Ionization_Energy(eV)",
"Electron_Affinity(eV)",
"LogP",
"Refractive_Index",
"Molecular_Weight(g/mol)"
]
def scaler_filename(key):
    # Reproduce the filename sanitization applied when the scalers were saved at training time
    safe = (key.replace('/', '_').replace(' ', '_')
               .replace('(', '').replace(')', '')
               .replace('[', '').replace(']', '')
               .replace('__', '_'))
    return f"scaler_{safe}.joblib"
scalers = [joblib.load(scaler_filename(key)) for key in target_keys]
# Descriptor function (must match training order)
def compute_descriptors(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
raise ValueError("Invalid SMILES string.")
return np.array([
Descriptors.MolWt(mol),
Descriptors.MolLogP(mol),
Descriptors.TPSA(mol),
Descriptors.NumRotatableBonds(mol),
Descriptors.NumHDonors(mol),
Descriptors.NumHAcceptors(mol),
Descriptors.FractionCSP3(mol),
Descriptors.HeavyAtomCount(mol),
Descriptors.RingCount(mol),
Descriptors.MolMR(mol)
], dtype=np.float32)
# Model architecture (must match training): the ChemBERTa embedding and the projected
# descriptor vector form a 2-token sequence that is passed through a transformer encoder,
# flattened, and fed to an MLP regression head
class TransformerRegressor(nn.Module):
def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
super().__init__()
self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.regressor = nn.Sequential(
nn.Flatten(),
nn.Linear(2 * d_model, 256),
nn.ReLU(),
nn.Linear(256, num_targets)
)
def forward(self, embedding, descriptors):
desc_proj = self.descriptor_proj(descriptors).unsqueeze(1) # (B, 1, d_model)
stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1) # (B, 2, d_model)
encoded = self.transformer(stacked) # (B, 2, d_model)
return self.regressor(encoded)
# Load trained model
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()
# Main prediction function
def predict_properties(smiles):
try:
# Compute descriptors
descriptors = compute_descriptors(smiles)
descriptors_tensor = torch.tensor(descriptors).unsqueeze(0) # (1, 10)
# Get embedding from ChemBERTa
inputs = tokenizer(smiles, return_tensors="pt")
with torch.no_grad():
outputs = embedding_model(**inputs)
embedding = outputs.last_hidden_state[:, 0, :] # (1, 768)
# Predict
with torch.no_grad():
preds = model(embedding, descriptors_tensor)
# Inverse transform each prediction
preds_np = preds.numpy().flatten()
preds_rescaled = [
scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
]
# Prepare results
        readable_keys = [
            "Tensile Strength (MPa)",
            "Ionization Energy (eV)",
            "Electron Affinity (eV)",
            "logP",
            "Refractive Index",
            "Molecular Weight (g/mol)",
        ]
results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
return results
except Exception as e:
return {"error": str(e)} |