import torch
import torch.nn as nn
import joblib
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
# Load ChemBERTa tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Load saved scalers for inverse transformations
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib"),
}

# Descriptor calculation
def compute_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ]
    return np.array(descriptors, dtype=np.float32)

# Transformer regression model definition (must match training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        x = self.feat_proj(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)
        return self.regression_head(x)

# Model hyperparameters (must match training)
embedding_dim = 768
descriptor_dim = 1290  # must match the descriptor feature length used to train the checkpoint
input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
ff_dim = 1024
num_layers = 2
output_dim = 6
# Load trained model
device = torch.device("cpu")
model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
model.eval()
# Prediction function
def predict_properties(smiles: str):
    try:
        # Compute RDKit descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)

        # Get ChemBERTa embedding (CLS token)
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Combine features into a single-token "sequence" for the encoder;
        # the combined width must equal the model's input_dim
        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # (1, 1, 768 + n_descriptors)

        # Forward pass
        with torch.no_grad():
            preds = model(combined)
        preds_np = preds.numpy()

        # Inverse transform each property back to its original scale
        keys = list(scalers.keys())
        preds_rescaled = np.concatenate([
            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
            for i in range(output_dim)
        ], axis=1)

        results = {key: round(float(val), 4) for key, val in zip(keys, preds_rescaled.flatten())}
        return results
    except Exception as e:
        return {"error": str(e)}

# Show function to print the results
def show(smiles: str):
    result = predict_properties(smiles)
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print("Predicted Properties for SMILES:", smiles)
        for key, value in result.items():
            print(f"{key}: {value}")