transpolymer committed on
Commit 3c1e1fb · verified · 1 Parent(s): bf11015

Delete transformer_model.py

Files changed (1)
  1. transformer_model.py +0 -108
transformer_model.py DELETED
@@ -1,108 +0,0 @@
- # transformer_model.py
-
- import torch
- import torch.nn as nn
- from transformers import AutoTokenizer, AutoModel
- from rdkit import Chem
- from rdkit.Chem import Descriptors, AllChem
- from sklearn.preprocessing import StandardScaler
- import numpy as np
-
- # Initialize Tokenizer and Model from ChemBERTa
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
- chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
- chemberta.eval()
-
- # Canonicalize a SMILES string; return None if it cannot be parsed
- def fix_smiles(s):
-     try:
-         mol = Chem.MolFromSmiles(s.strip())
-         if mol:
-             return Chem.MolToSmiles(mol)
-     except Exception:
-         pass
-     return None
-
- # Compute descriptors + fingerprints (10 descriptors + 2048 bits = 2058 features)
- def compute_features(smiles):
-     mol = Chem.MolFromSmiles(smiles)
-     if not mol:
-         return [0] * 10 + [0] * 2048
-     descriptor_fns = [
-         Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
-         Descriptors.NumRotatableBonds, Descriptors.NumHDonors,
-         Descriptors.NumHAcceptors, Descriptors.FractionCSP3,
-         Descriptors.HeavyAtomCount, Descriptors.RingCount, Descriptors.MolMR
-     ]
-     desc = [fn(mol) for fn in descriptor_fns]
-     fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
-     return desc + list(fp)
-
- # Embedding function using ChemBERTa
- @torch.no_grad()
- def embed_smiles(smiles_list):
-     inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
-     outputs = chemberta(**inputs)
-     return outputs.last_hidden_state[:, 0, :]  # CLS token
-
- # Model Definition (Transformer Regressor)
- class TransformerRegressor(nn.Module):
-     def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
-         super().__init__()
-         self.feat_proj = nn.Linear(feat_dim, emb_dim)  # Project features to embedding space
-
-         encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=1024, dropout=0.1, batch_first=True)
-         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)  # Transformer Encoder
-
-         # Regression head
-         self.regression_head = nn.Sequential(
-             nn.Linear(emb_dim, 256), nn.ReLU(),
-             nn.Linear(256, 128), nn.ReLU(),
-             nn.Linear(128, output_dim)
-         )
-
-     def forward(self, x, feat):
-         feat_emb = self.feat_proj(feat)  # [B, 768]
-         stacked = torch.stack([x, feat_emb], dim=1)  # Stack SMILES embedding and features [B, 2, 768]
-         encoded = self.transformer_encoder(stacked)  # Transformer encoding
-         aggregated = encoded.mean(dim=1)  # Aggregate encoded sequence
-         return self.regression_head(aggregated)  # Regression output
-
- # Ensemble prediction class
- class EnsembleModel:
-     def __init__(self, model_paths, device):
-         self.models = []
-         self.device = device
-         self.load_models(model_paths)
-
-     def load_models(self, model_paths):
-         for path in model_paths:
-             model = TransformerRegressor().to(self.device)
-             model.load_state_dict(torch.load(path, map_location=self.device))
-             model.eval()
-             self.models.append(model)
-
-     def predict(self, smiles, features_tensor):
-         # Clean and canonicalize the SMILES
-         cleaned_smiles = fix_smiles(smiles)
-         if not cleaned_smiles:
-             raise ValueError("Invalid SMILES string.")
-
-         # Embed the cleaned SMILES with ChemBERTa
-         cls_embedding = embed_smiles([cleaned_smiles]).to(self.device)
-
-         # Predict using the ensemble
-         preds_all = []
-         for model in self.models:
-             with torch.no_grad():
-                 pred = model(cls_embedding, features_tensor)
-                 preds_all.append(pred)
-
-         # Average the predictions across the models
-         preds_ensemble = torch.stack(preds_all).mean(dim=0)
-         return preds_ensemble.cpu().numpy()
-
- # Helper function to inverse-transform predictions (one fitted scaler per target column)
- def inverse_transform_predictions(y_pred, scalers):
-     return np.column_stack([scaler.inverse_transform(y_pred[:, i:i+1]) for i, scaler in enumerate(scalers)])
-
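
For context, a minimal sketch of how this module was presumably wired together before its removal. The checkpoint paths, the toy SMILES input, and the scalers list are illustrative assumptions, not artifacts of this commit:

    import torch
    from transformer_model import EnsembleModel, compute_features, inverse_transform_predictions

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical checkpoint paths; no weights ship with this repository
    ensemble = EnsembleModel(["fold0.pt", "fold1.pt"], device)

    smiles = "CCO"  # toy input (ethanol)
    features = torch.tensor([compute_features(smiles)], dtype=torch.float32).to(device)  # [1, 2058]

    y_scaled = ensemble.predict(smiles, features)  # numpy array of shape [1, 6], in scaled units
    # 'scalers' would be the per-target StandardScaler objects fitted at training time (not in this file)
    # y = inverse_transform_predictions(y_scaled, scalers)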