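"""Predict six molecular properties from a SMILES string: a ChemBERTa CLS
embedding is fed to a small Transformer regressor, and each output is
inverse-scaled back to physical units with the scalers saved at training time."""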
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load tokenizer and model for embeddings
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load saved scalers (for inverse_transform)
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
scaler_logp = joblib.load("scaler_LogP.joblib")
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
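# NOTE (assumption): the .joblib files hold per-target sklearn scalers (e.g.
# StandardScaler) fitted during training; inverse_transform maps the model's
# normalized outputs back to physical units (MPa, eV, g/mol, ...).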

# Descriptor function with exact order from training
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES")

    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)
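
# Sanity check (hypothetical input): ethanol should parse cleanly and yield a
# (10,) float32 vector in the order listed above:
#   compute_descriptors("CCO").shape == (10,)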

# Model definition; must match the training architecture exactly, or load_state_dict will fail
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        x = self.feat_proj(x)            # project input features to hidden_dim
        x = self.transformer_encoder(x)  # self-attention over the sequence
        x = x.mean(dim=1)                # mean-pool over dim 1 (a singleton here)
        return self.regression_head(x)   # pooled features -> output_dim targets

# Set model hyperparameters (must match training config)
input_dim = 768  # ChemBERTa embedding size
hidden_dim = 256
num_layers = 2
output_dim = 6   # Number of properties predicted

# Load model weights (on recent PyTorch, torch.load(..., weights_only=True) is
# a safer choice for plain state_dicts like this one)
model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()

# Main prediction function
def predict_properties(smiles):
    try:
        # compute_descriptors doubles as SMILES validation here (it raises on
        # bad input); the network itself consumes only the ChemBERTa embedding,
        # since input_dim is 768 with no room for the 10-descriptor vector.
        compute_descriptors(smiles)

        # Get embedding
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        emb = outputs.last_hidden_state[:, 0, :]  # CLS token output, shape (1, 768)
        # Add a sequence dimension so the encoder sees a length-1 sequence and
        # forward()'s mean(dim=1) pools the intended axis: (1, 768) -> (1, 1, 768).
        emb = emb.unsqueeze(1)

        # Forward pass
        with torch.no_grad():
            preds = model(emb)

        preds_np = preds.numpy()
        preds_rescaled = np.concatenate([
            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
            scaler_logp.inverse_transform(preds_np[:, [3]]),
            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
        ], axis=1)

        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
        # Cast to plain Python floats so the dict prints and serializes cleanly
        results = {k: float(v) for k, v in zip(keys, preds_rescaled.flatten().round(4))}

        return results

    except Exception as e:
        return {"error": str(e)}