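# Inference script: predict six molecular/material properties from a SMILES
# string by concatenating a ChemBERTa [CLS] embedding with RDKit descriptors
# and passing the combined feature vector through a Transformer regressor.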
import torch
import torch.nn as nn
import joblib
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
# Load ChemBERTa tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load saved scalers for inverse transformations
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib"),
}
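# NOTE: the insertion order of this dict defines the column order used when
# inverse-transforming model outputs in predict_properties() below, so it
# must match the target column order used during training.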
# Descriptor calculation
def compute_descriptors(smiles: str):
    """Compute a 10-element RDKit descriptor vector for a SMILES string."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ]
    return np.array(descriptors, dtype=np.float32)
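# Quick sanity check (illustrative only; "CCO" = ethanol is an assumed example
# input, not part of the original script):
#   compute_descriptors("CCO").shape  # -> (10,)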
# Transformer regression model definition (must match training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
        super().__init__()
        # Project the concatenated feature vector down to the encoder width
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        x = self.feat_proj(x)           # (batch, seq, embedding_dim)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)               # mean-pool over the sequence dimension
        return self.regression_head(x)
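# Shape check (illustrative sketch; not part of the original script):
#   m = TransformerRegressor(input_dim=2058, embedding_dim=768,
#                            ff_dim=1024, num_layers=2, output_dim=6)
#   m(torch.randn(2, 1, 2058)).shape  # -> torch.Size([2, 6])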
# Model hyperparameters (must match training)
embedding_dim = 768    # ChemBERTa hidden size
descriptor_dim = 1290  # Descriptor vector length the checkpoint was trained with
input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
ff_dim = 1024
num_layers = 2
output_dim = 6
# NOTE: compute_descriptors() above returns only 10 values; for the saved
# weights (input_dim = 2058) to apply, the descriptor pipeline must produce
# the same 1290 features used during training.
# Load trained model
device = torch.device("cpu")
model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
model.eval()
# Prediction function
def predict_properties(smiles: str):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)

        # Get ChemBERTa embedding (CLS token)
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Concatenate embedding and descriptors, then add a sequence dimension
        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # (1, 1, input_dim)

        # Forward pass
        with torch.no_grad():
            preds = model(combined)
        preds_np = preds.numpy()

        # Inverse-transform each property back to its original scale
        keys = list(scalers.keys())
        preds_rescaled = np.concatenate([
            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
            for i in range(output_dim)
        ], axis=1)

        results = {key: round(float(val), 4) for key, val in zip(keys, preds_rescaled.flatten())}
        return results
    except Exception as e:
        return {"error": str(e)}
# Show function to print the results
def show(smiles: str):
    result = predict_properties(smiles)
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print("Predicted Properties for SMILES:", smiles)
        for key, value in result.items():
            print(f"{key}: {value}")