File size: 4,089 Bytes
bcfbb05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# transformer_model.py

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from sklearn.preprocessing import StandardScaler
import numpy as np

# Load the pretrained ChemBERTa tokenizer and encoder once, at import time.
# The encoder is put in eval mode (disables dropout) because this module
# only uses it for inference; Module.eval() returns the module itself.
_CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_CHECKPOINT)
chemberta = AutoModel.from_pretrained(_CHEMBERTA_CHECKPOINT).eval()

# Function to fix SMILES
def fix_smiles(s):
    """Return the RDKit canonical SMILES for *s*, or ``None`` if it is unusable.

    Leading/trailing whitespace is stripped before parsing. The function is
    deliberately best-effort: any input RDKit cannot parse, or that is not a
    string at all (e.g. ``None``), yields ``None`` instead of raising, so
    callers can use the result as a validity check.

    Args:
        s: Candidate SMILES string.

    Returns:
        The canonical SMILES string, or ``None``.
    """
    try:
        mol = Chem.MolFromSmiles(s.strip())
        # MolFromSmiles reports a parse failure by returning None.
        if mol is not None:
            return Chem.MolToSmiles(mol)
    except Exception:
        # Catch ordinary errors only (e.g. AttributeError for non-str input or
        # an RDKit argument error). A bare ``except:`` would also swallow
        # KeyboardInterrupt and SystemExit.
        pass
    return None

# Function to compute descriptors + fingerprints
def compute_features(smiles, radius=2, n_bits=2048):
    """Build a flat feature vector: ten RDKit descriptors plus a Morgan fingerprint.

    Args:
        smiles: SMILES string to featurize.
        radius: Morgan fingerprint radius (default 2).
        n_bits: Fingerprint length in bits (default 2048).

    Returns:
        A list of ``10 + n_bits`` numbers. The descriptor values come first, in
        the order of ``descriptor_fns`` below, followed by the fingerprint bits
        (0/1). If *smiles* cannot be parsed, a list of the same length filled
        with 0 is returned instead.
    """
    descriptor_fns = [
        Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
        Descriptors.NumRotatableBonds, Descriptors.NumHDonors,
        Descriptors.NumHAcceptors, Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount, Descriptors.RingCount, Descriptors.MolMR,
    ]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Derive the padding length from the descriptor list and fingerprint
        # size, so the failure path can never disagree with the success path
        # on vector length.
        return [0] * (len(descriptor_fns) + n_bits)
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
    return desc + list(fp)

# Embedding function using ChemBERTa
def embed_smiles(smiles_list):
    """Encode a batch of SMILES strings with ChemBERTa and return first-token embeddings.

    The batch is padded to its longest sequence and truncated to at most 128
    tokens. Gradient tracking is disabled for the whole call.

    Args:
        smiles_list: List of SMILES strings.

    Returns:
        A tensor taken from position 0 (the CLS token) of the encoder's last
        hidden state, with one row per input string.
    """
    with torch.no_grad():
        batch = tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        hidden_states = chemberta(**batch).last_hidden_state
        return hidden_states[:, 0, :]

# Model Definition (Transformer Regressor)
class TransformerRegressor(nn.Module):
    """Regress ``output_dim`` targets from a SMILES embedding plus a feature vector.

    The SMILES embedding and a linear projection of the feature vector form a
    two-token sequence. That sequence is passed through a Transformer encoder,
    mean-pooled over the two tokens, and fed to an MLP regression head.

    Args:
        emb_dim: Width of the SMILES embedding and of the encoder (``d_model``).
        feat_dim: Length of the feature vector. The default 2058 matches the
            10 descriptors + 2048 fingerprint bits built by ``compute_features``.
        output_dim: Number of regression targets.
        nhead: Attention heads per encoder layer; must divide ``emb_dim``.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2,
                 dim_feedforward=1024, dropout=0.1):
        super().__init__()
        # Project the feature vector into the embedding space so it can sit
        # next to the SMILES embedding as a second token.
        self.feat_proj = nn.Linear(feat_dim, emb_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Regression head. Keep this exact Sequential layout: saved state_dicts
        # address its Linear layers by index (regression_head.0 / .2 / .4).
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        """Predict regression targets for a batch.

        Args:
            x: SMILES embeddings, shape [B, emb_dim].
            feat: Feature vectors, shape [B, feat_dim].

        Returns:
            Tensor of shape [B, output_dim].
        """
        feat_emb = self.feat_proj(feat)  # [B, emb_dim]
        tokens = torch.stack([x, feat_emb], dim=1)  # [B, 2, emb_dim]
        encoded = self.transformer_encoder(tokens)  # [B, 2, emb_dim]
        pooled = encoded.mean(dim=1)  # average over the two tokens
        return self.regression_head(pooled)

# Ensemble prediction class
class EnsembleModel:
    """Average the predictions of several ``TransformerRegressor`` checkpoints.

    Args:
        model_paths: Iterable of paths to state_dict files. Each one is loaded
            into a fresh ``TransformerRegressor()`` with default
            hyper-parameters.
        device: Device the models, and the inputs passed to them, are placed on.
    """

    def __init__(self, model_paths, device):
        self.models = []
        self.device = device
        self.load_models(model_paths)

    def load_models(self, model_paths):
        """Load each checkpoint in *model_paths* and append it, in eval mode, to ``self.models``."""
        for path in model_paths:
            model = TransformerRegressor().to(self.device)
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. Since only a state_dict is needed, passing
            # weights_only=True (torch >= 1.13) would be safer.
            state_dict = torch.load(path, map_location=self.device)
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)

    def predict(self, smiles, features_tensor):
        """Predict targets for one molecule by averaging over all loaded models.

        Args:
            smiles: SMILES string of the molecule; it is canonicalized first.
            features_tensor: Feature vector for the same molecule, as a tensor,
                NumPy array or list. A 1-D input is treated as a batch of one.
                It is converted to float32 on ``self.device``. Its batch size
                must match the single SMILES embedding (presumably 1).

        Returns:
            NumPy array with the element-wise mean of the models' outputs.

        Raises:
            RuntimeError: If no models have been loaded.
            ValueError: If *smiles* is not a valid SMILES string.
        """
        if not self.models:
            # Fail clearly up front instead of inside torch.stack([]).
            raise RuntimeError("EnsembleModel has no loaded models.")

        # Clean and embed SMILES
        cleaned_smiles = fix_smiles(smiles)
        if not cleaned_smiles:
            raise ValueError("Invalid SMILES string.")
        cls_embedding = embed_smiles([cleaned_smiles]).to(self.device)

        # Put the features on the same device/dtype as the models. Without
        # this, a CPU tensor crashes GPU models and float64 input fails in
        # the Linear projection.
        feats = torch.as_tensor(features_tensor, dtype=torch.float32, device=self.device)
        if feats.dim() == 1:
            feats = feats.unsqueeze(0)

        # Predict using the ensemble
        with torch.no_grad():
            preds_all = [model(cls_embedding, feats) for model in self.models]

        # Average the predictions across the models
        preds_ensemble = torch.stack(preds_all).mean(dim=0)
        return preds_ensemble.cpu().numpy()

# Helper function to inverse transform predictions
def inverse_transform_predictions(y_pred, scalers):
    """Undo per-target scaling, column by column.

    Column ``i`` of the 2-D array *y_pred* is passed, as an (n, 1) slice, to
    ``scalers[i].inverse_transform``. The restored columns are stacked back
    side by side. Columns beyond ``len(scalers)`` are not included in the
    result.

    Args:
        y_pred: 2-D NumPy array of scaled predictions.
        scalers: Sequence of fitted scalers, one per target column.

    Returns:
        2-D NumPy array with one column per scaler.
    """
    restored_columns = []
    for col_idx, col_scaler in enumerate(scalers):
        column = y_pred[:, col_idx:col_idx + 1]
        restored_columns.append(col_scaler.inverse_transform(column))
    return np.column_stack(restored_columns)