import torch
import torch.nn as nn
import joblib
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel

# Load ChemBERTa tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load saved scalers for inverse transformations
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
}

# Descriptor calculation
def compute_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ]
    return np.array(descriptors, dtype=np.float32)

# Transformer regression model definition (must match training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.feat_proj(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)
        return self.regression_head(x)

# Model hyperparameters (must match training)
embedding_dim = 768
descriptor_dim = 1290  # Must match the descriptor vector length used during training (note: compute_descriptors above returns 10 features, so the two must be kept consistent)
input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
ff_dim = 1024
num_layers = 2
output_dim = 6

# Load trained model
device = torch.device("cpu")
model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
model.eval()

# Prediction function
def predict_properties(smiles: str):
    try:
        # Compute descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)

        # Get ChemBERTa embedding (CLS token)
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Combine features
        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # Shape: (1, 1, input_dim)

        # Forward pass
        with torch.no_grad():
            preds = model(combined)

        preds_np = preds.numpy()

        # Inverse transform each property
        keys = list(scalers.keys())
        preds_rescaled = np.concatenate([
            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
            for i in range(output_dim)
        ], axis=1)

        results = {key: round(float(val), 4) for key, val in zip(keys, preds_rescaled.flatten())}
        return results

    except Exception as e:
        return {"error": str(e)}

# Print the predicted properties for a SMILES string
def show(smiles: str):
    result = predict_properties(smiles)
    
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print("Predicted Properties for SMILES:", smiles)
        for key, value in result.items():
            print(f"{key}: {value}")
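
# Example usage (illustrative only): "CC(=O)OC1=CC=CC=C1C(=O)O" (aspirin) is an
# arbitrary demonstration SMILES, not taken from the training data; any valid
# SMILES string can be passed to show().
if __name__ == "__main__":
    show("CC(=O)OC1=CC=CC=C1C(=O)O")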