import torch
import torch.nn as nn
import joblib
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
import numpy as np

# Load tokenizer and pretrained ChemBERTa embedding model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model.eval()  # inference only; keep dropout disabled

# Load one fitted scaler per target property
target_keys = [
    "Tensile_strength(Mpa)", 
    "Ionization_Energy(eV)", 
    "Electron_Affinity(eV)", 
    "LogP", 
    "Refractive_Index", 
    "Molecular_Weight(g/mol)"
]

def _scaler_filename(key):
    # Mirror the filename sanitisation used when the scalers were saved
    safe = (key.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '')
               .replace('[', '').replace(']', '').replace('__', '_'))
    return f"scaler_{safe}.joblib"

scalers = [joblib.load(_scaler_filename(key)) for key in target_keys]

# RDKit descriptor function (order must match the order used during training)
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ], dtype=np.float32)

# Model class (architecture must match the one used during training)
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
        super().__init__()
        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2 * d_model, 256),
            nn.ReLU(),
            nn.Linear(256, num_targets)
        )

    def forward(self, embedding, descriptors):
        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
        encoded = self.transformer(stacked)  # (B, 2, d_model)
        return self.regressor(encoded)
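
# Shape sanity check (a sketch with random tensors and the default dimensions
# above; not part of the inference pipeline):
#   >>> m = TransformerRegressor()
#   >>> m(torch.randn(2, 768), torch.randn(2, 10)).shape
#   torch.Size([2, 6])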

# Load trained model
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()

# Main prediction function: SMILES string -> dict of predicted properties
def predict_properties(smiles):
    try:
        # Compute descriptors
        descriptors = compute_descriptors(smiles)
        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)

        # Get embedding from ChemBERTa
        inputs = tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)

        # Predict
        with torch.no_grad():
            preds = model(embedding, descriptors_tensor)

        # Inverse transform each prediction
        preds_np = preds.numpy().flatten()
        preds_rescaled = [
            scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
        ]

        # Prepare results
        readable_keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
        results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))

        return results

    except Exception as e:
        return {"error": str(e)}