transpolymer committed
Commit bcfbb05 · verified · 1 Parent(s): 833f4a5

Create transformer_model.py

Files changed (1)
  1. transformer_model.py +108 -0
transformer_model.py ADDED
@@ -0,0 +1,108 @@
# transformer_model.py

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from sklearn.preprocessing import StandardScaler
import numpy as np

# Initialize the tokenizer and model from ChemBERTa
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta.eval()

# Canonicalize a SMILES string; returns None if it cannot be parsed
def fix_smiles(s):
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        pass
    return None

# Compute descriptors + fingerprints for a SMILES string
def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048  # zero vector for unparseable SMILES
    descriptor_fns = [
        Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
        Descriptors.NumRotatableBonds, Descriptors.NumHDonors,
        Descriptors.NumHAcceptors, Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount, Descriptors.RingCount, Descriptors.MolMR
    ]
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)

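# Sanity check (illustrative sketch, values follow from the code above): the
# feature vector is 10 descriptors + 2048 fingerprint bits = 2058 entries,
# which matches feat_dim in TransformerRegressor below. For example:
#   feats = compute_features("CCO")   # ethanol
#   len(feats)  # -> 2058
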
# Embedding function using ChemBERTa
@torch.no_grad()
def embed_smiles(smiles_list):
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # [CLS] token embedding, shape [B, 768]

# Model definition (Transformer regressor)
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)  # Project features into the embedding space

        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=1024, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Regression head
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)              # [B, 768]
        stacked = torch.stack([x, feat_emb], dim=1)  # Stack SMILES embedding and feature embedding: [B, 2, 768]
        encoded = self.transformer_encoder(stacked)  # Self-attention over the two tokens
        aggregated = encoded.mean(dim=1)             # Mean-pool the encoded sequence
        return self.regression_head(aggregated)      # [B, output_dim]

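# Shape walkthrough (illustrative sketch using the defaults above): the forward
# pass takes a ChemBERTa [CLS] embedding of shape [B, 768] and a feature tensor
# of shape [B, 2058], and returns predictions of shape [B, 6]:
#   model = TransformerRegressor()
#   x = torch.randn(4, 768)       # batch of 4 [CLS] embeddings
#   feat = torch.randn(4, 2058)   # batch of 4 descriptor+fingerprint vectors
#   model(x, feat).shape          # -> torch.Size([4, 6])
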
# Ensemble prediction class
class EnsembleModel:
    def __init__(self, model_paths, device):
        self.models = []
        self.device = device
        self.load_models(model_paths)

    def load_models(self, model_paths):
        for path in model_paths:
            model = TransformerRegressor().to(self.device)
            model.load_state_dict(torch.load(path, map_location=self.device))
            model.eval()
            self.models.append(model)

    def predict(self, smiles, features_tensor):
        # Clean and canonicalize the SMILES
        cleaned_smiles = fix_smiles(smiles)
        if not cleaned_smiles:
            raise ValueError("Invalid SMILES string.")

        # Embed the SMILES with ChemBERTa
        cls_embedding = embed_smiles([cleaned_smiles]).to(self.device)

        # Predict using each model in the ensemble
        preds_all = []
        for model in self.models:
            with torch.no_grad():
                pred = model(cls_embedding, features_tensor)
                preds_all.append(pred)

        # Average the predictions across the models
        preds_ensemble = torch.stack(preds_all).mean(dim=0)
        return preds_ensemble.cpu().numpy()

# Helper function to inverse transform predictions
def inverse_transform_predictions(y_pred, scalers):
    return np.column_stack([scaler.inverse_transform(y_pred[:, i:i+1]) for i, scaler in enumerate(scalers)])

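# Usage sketch: the checkpoint paths and scalers below are placeholders (they
# come from a separate training pipeline, not from this file):
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble = EnsembleModel(["model_fold0.pt", "model_fold1.pt"], device)

    smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin
    features = torch.tensor([compute_features(smiles)], dtype=torch.float32).to(device)

    preds_scaled = ensemble.predict(smiles, features)  # shape [1, 6], in scaled units
    # With the six per-target StandardScaler objects fitted during training:
    # preds = inverse_transform_predictions(preds_scaled, scalers)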