# transformer_model.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from sklearn.preprocessing import StandardScaler
import numpy as np

# Initialize tokenizer and model from ChemBERTa
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta.eval()
# Canonicalize a SMILES string; return None if it cannot be parsed
def fix_smiles(s):
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        pass
    return None
# Compute 10 RDKit descriptors plus a 2048-bit Morgan fingerprint
def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        # Fallback for unparseable SMILES: all-zero feature vector (10 + 2048)
        return [0] * 10 + [0] * 2048
    descriptor_fns = [
        Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
        Descriptors.NumRotatableBonds, Descriptors.NumHDonors,
        Descriptors.NumHAcceptors, Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount, Descriptors.RingCount, Descriptors.MolMR
    ]
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)
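
# Usage sketch (illustrative): the feature vector is 10 descriptor values
# followed by 2048 Morgan fingerprint bits (2058 total), matching the
# model's feat_dim below.
#   feats = compute_features("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
#   len(feats)  # -> 2058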
# Embed SMILES strings with ChemBERTa and return the [CLS] token vectors
def embed_smiles(smiles_list):
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # [batch, 768] CLS embeddings
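
# Usage sketch (illustrative): each SMILES maps to one 768-dimensional CLS vector.
#   emb = embed_smiles(["CCO", "c1ccccc1"])
#   emb.shape  # -> torch.Size([2, 768])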
# Model definition: transformer encoder over [CLS embedding, projected features]
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)  # Project features to embedding space
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=1024, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)  # Transformer encoder
        # Regression head
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)              # [B, 768]
        stacked = torch.stack([x, feat_emb], dim=1)  # Stack SMILES embedding and features: [B, 2, 768]
        encoded = self.transformer_encoder(stacked)  # Transformer encoding over the 2-token sequence
        aggregated = encoded.mean(dim=1)             # Mean-pool the encoded sequence
        return self.regression_head(aggregated)      # [B, output_dim]
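
# Shape sketch (illustrative, with random tensors in place of real inputs):
#   model = TransformerRegressor()
#   x = torch.randn(4, 768)      # batch of ChemBERTa CLS embeddings
#   feat = torch.randn(4, 2058)  # batch of descriptor + fingerprint vectors
#   model(x, feat).shape         # -> torch.Size([4, 6])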
# Ensemble prediction over several trained TransformerRegressor checkpoints
class EnsembleModel:
    def __init__(self, model_paths, device):
        self.models = []
        self.device = device
        self.load_models(model_paths)

    def load_models(self, model_paths):
        for path in model_paths:
            model = TransformerRegressor().to(self.device)
            model.load_state_dict(torch.load(path, map_location=self.device))
            model.eval()
            self.models.append(model)

    def predict(self, smiles, features_tensor):
        # Clean and canonicalize the SMILES
        cleaned_smiles = fix_smiles(smiles)
        if not cleaned_smiles:
            raise ValueError("Invalid SMILES string.")
        # Embed the SMILES with ChemBERTa
        cls_embedding = embed_smiles([cleaned_smiles]).to(self.device)
        # Predict with every model in the ensemble
        preds_all = []
        for model in self.models:
            with torch.no_grad():
                pred = model(cls_embedding, features_tensor)
            preds_all.append(pred)
        # Average the predictions across the models
        preds_ensemble = torch.stack(preds_all).mean(dim=0)
        return preds_ensemble.cpu().numpy()
# Helper function to inverse-transform predictions, one fitted scaler per target column
def inverse_transform_predictions(y_pred, scalers):
    return np.column_stack([scaler.inverse_transform(y_pred[:, i:i+1]) for i, scaler in enumerate(scalers)])
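
# Minimal end-to-end sketch. The checkpoint paths and the per-target scalers are
# assumptions: trained weights and fitted StandardScaler objects must exist for
# the commented lines to run.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, used as a placeholder input
    cleaned = fix_smiles(smiles)
    features = torch.tensor([compute_features(cleaned)], dtype=torch.float32).to(device)

    # Hypothetical checkpoint paths; replace with the actual trained fold weights.
    # ensemble = EnsembleModel(["model_fold0.pt", "model_fold1.pt"], device)
    # y_pred = ensemble.predict(cleaned, features)

    # `target_scalers` is assumed to be a list of fitted StandardScaler objects,
    # one per regression target, saved during training.
    # y_original = inverse_transform_predictions(y_pred, target_scalers)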