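"""Predict molecular properties from a SMILES string.

Pipeline: a ChemBERTa (seyonec/ChemBERTa-zinc-base-v1) sequence embedding and
10 RDKit descriptors are fed to a small Transformer-encoder regressor, whose
six outputs are inverse-transformed back to physical units using the scalers
saved at training time.
"""
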
import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load the ChemBERTa tokenizer and encoder used to embed SMILES strings
# (fetched from the Hugging Face Hub on first use)
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Load the per-target scalers saved at training time (needed to
# inverse-transform the model's normalized outputs back to physical units).
# Filenames are kept exactly as saved; the "lonization" spelling in the
# ionization-energy artifact appears to come from the training script.
scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
scaler_ionization_energy = joblib.load("scaler_lonization_Energy_eV_.joblib")
scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
scaler_logp = joblib.load("scaler_LogP.joblib")
scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")

# Compute the 10 RDKit descriptors, in the exact order used during training
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)
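
# Example: compute_descriptors("CCO") returns a length-10 float32 array
# (MolWt first, MolMR last); an unparseable string raises ValueError.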

# Model definition; must match the training-time architecture so the saved state dict loads
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets)
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)  # (B, 2, d_model)
        output = self.regressor(encoded)
        return output
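
# Design note: the sequence embedding and the projected descriptor vector are
# treated as a length-2 "sequence", so self-attention can mix the two views
# before the flattened pair is passed to the MLP head.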

# Instantiate the regressor and load the trained weights onto CPU
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()

# Main prediction function
def predict_properties(smiles):
    try:
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)

        # Get embedding
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        emb = outputs.last_hidden_state[:, 0, :]  # first-token (<s>/[CLS]) embedding, shape (1, 768)

        # Forward pass
        with torch.no_grad():
            preds = model(emb, descriptors_tensor)

        # Inverse transform predictions using respective scalers
        preds_np = preds.numpy()
        preds_rescaled = np.concatenate([
            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
            scaler_logp.inverse_transform(preds_np[:, [3]]),
            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
        ], axis=1)

        # Round and package as plain Python floats (np.float32 values do not
        # serialize cleanly to JSON)
        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
        results = {k: float(v) for k, v in zip(keys, preds_rescaled.flatten().round(4))}

        return results

    except Exception as e:
        return {"error": str(e)}