# transformer_model.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from sklearn.preprocessing import StandardScaler
import numpy as np

# Initialize tokenizer and model from ChemBERTa
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta.eval()


# Canonicalize a SMILES string; return None if it cannot be parsed
def fix_smiles(s):
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        pass
    return None


# Compute 10 RDKit descriptors + a 2048-bit Morgan fingerprint (2058 features total)
def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048
    descriptor_fns = [
        Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
        Descriptors.NumRotatableBonds, Descriptors.NumHDonors,
        Descriptors.NumHAcceptors, Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount, Descriptors.RingCount, Descriptors.MolMR
    ]
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)


# Embedding function using ChemBERTa
@torch.no_grad()
def embed_smiles(smiles_list):
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True,
                       truncation=True, max_length=128)
    outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # CLS token


# Model definition (Transformer regressor)
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        # Project handcrafted features (10 descriptors + 2048-bit fingerprint) to embedding space
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=1024,
            dropout=0.1, batch_first=True
        )
        # Transformer encoder over the two-token sequence (SMILES embedding, feature embedding)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Regression head
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)              # [B, 768]
        stacked = torch.stack([x, feat_emb], dim=1)  # [B, 2, 768]: SMILES embedding + feature embedding
        encoded = self.transformer_encoder(stacked)  # Transformer encoding
        aggregated = encoded.mean(dim=1)             # Mean-pool over the two tokens
        return self.regression_head(aggregated)      # [B, output_dim]


# Ensemble prediction class
class EnsembleModel:
    def __init__(self, model_paths, device):
        self.models = []
        self.device = device
        self.load_models(model_paths)

    def load_models(self, model_paths):
        for path in model_paths:
            model = TransformerRegressor().to(self.device)
            model.load_state_dict(torch.load(path, map_location=self.device))
            model.eval()
            self.models.append(model)

    def predict(self, smiles, features_tensor):
        # Canonicalize the SMILES; fail fast on invalid input
        cleaned_smiles = fix_smiles(smiles)
        if not cleaned_smiles:
            raise ValueError("Invalid SMILES string.")
        # Embed SMILES with ChemBERTa (CLS token)
        cls_embedding = embed_smiles([cleaned_smiles]).to(self.device)
        # Predict with every model in the ensemble
        preds_all = []
        for model in self.models:
            with torch.no_grad():
                pred = model(cls_embedding, features_tensor)
            preds_all.append(pred)
        # Average the predictions across the models
        preds_ensemble = torch.stack(preds_all).mean(dim=0)
        return preds_ensemble.cpu().numpy()


# Helper function to inverse-transform predictions (one fitted scaler per target column)
def inverse_transform_predictions(y_pred, scalers):
    return np.column_stack([
        scaler.inverse_transform(y_pred[:, i:i + 1]) for i, scaler in enumerate(scalers)
    ])
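

# --- Example usage: a minimal sketch of how the pieces above fit together.
# --- The checkpoint paths ("model_fold0.pt", ...) and the test SMILES are
# --- hypothetical placeholders, not part of this module; in practice the
# --- feature tensor should be scaled the same way as during training.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical ensemble checkpoints (e.g. one per cross-validation fold)
    model_paths = ["model_fold0.pt", "model_fold1.pt", "model_fold2.pt"]
    ensemble = EnsembleModel(model_paths, device)

    smiles = "CCO"  # ethanol, as a simple test molecule
    features = compute_features(fix_smiles(smiles))
    features_tensor = torch.tensor([features], dtype=torch.float32).to(device)

    preds = ensemble.predict(smiles, features_tensor)  # shape [1, 6], in scaled target space
    print(preds)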