File size: 6,536 Bytes
4b283df
84dad8f
5e9e549
b73e3e2
4b283df
 
84dad8f
5e9e549
af4f3b0
4b283df
 
cf36af6
c621eb3
cf36af6
 
 
 
 
 
b69e05d
 
cf36af6
4b283df
 
 
 
 
3b0f51a
4b283df
b69e05d
4b283df
cf36af6
b73e3e2
 
 
 
 
 
 
 
5e9e549
cf36af6
5e9e549
4b283df
5e9e549
3de6f45
 
 
 
 
cf36af6
3de6f45
 
eea9e94
 
3de6f45
5e9e549
3de6f45
eea9e94
3de6f45
5e9e549
 
eea9e94
 
 
3de6f45
eea9e94
 
4b283df
 
ca63203
 
 
 
3b0f51a
ca63203
 
 
 
 
4b283df
3b0f51a
4b283df
cf36af6
4b283df
 
 
 
c3f06b5
4b283df
 
 
 
 
 
 
 
 
 
 
 
 
 
cf36af6
af4f3b0
 
 
 
 
cf36af6
af4f3b0
cf36af6
4b283df
 
 
 
cf36af6
4b283df
cf36af6
4b283df
 
 
 
 
 
 
 
 
 
cf36af6
4b283df
 
 
 
 
 
 
 
 
 
 
 
 
 
cf36af6
af4f3b0
cf36af6
 
4b283df
cf36af6
4b283df
cf36af6
 
4b283df
 
 
 
 
 
cf36af6
4b283df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from datetime import datetime
from db import get_database  # This must be available in your repo
import random

# ------------------------ Ensuring Deterministic Behavior ------------------------
# One seed shared by every random number generator this module touches
# (Python's `random`, NumPy, and PyTorch).
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# cuDNN settings: turn off the benchmark mode and request deterministic
# algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
@st.cache_resource
def load_chemberta():
    """Load the ChemBERTa tokenizer and encoder used for SMILES embeddings.

    Both objects are created with the ``transformers`` auto classes from the
    same checkpoint name. The encoder is put into eval mode and moved to the
    module-level ``device``. The function is wrapped in
    ``st.cache_resource`` so Streamlit can reuse the returned objects.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    encoder.eval()
    encoder.to(device)  # GPU when available, see `device`
    return tokenizer, encoder


# ------------------------ Load Scalers ------------------------
# Property label -> joblib file holding the scaler for that property.
# The order of this table is significant: show() pairs the i-th model output
# column with the i-th key of `scalers`.
_SCALER_FILES = (
    ("Tensile Strength", "scaler_Tensile_strength_Mpa_.joblib"),
    ("Ionization Energy", "scaler_Ionization_Energy_eV_.joblib"),
    ("Electron Affinity", "scaler_Electron_Affinity_eV_.joblib"),
    ("logP", "scaler_LogP.joblib"),
    ("Refractive Index", "scaler_Refractive_Index.joblib"),
    ("Molecular Weight", "scaler_Molecular_Weight_g_mol_.joblib"),
)

# Loaded once at import time; each value must expose `inverse_transform`,
# which show() uses to map model outputs back to physical units.
scalers = {label: joblib.load(path) for label, path in _SCALER_FILES}

# ------------------------ Transformer Model ------------------------
class TransformerRegressor(nn.Module):
    """Transformer-encoder regressor over feature vectors.

    Each input vector is linearly projected to ``embedding_dim``, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks (8 attention
    heads, no dropout, ``batch_first=True``), averaged over dimension 1 and
    fed to a three-layer MLP that emits ``output_dim`` values.

    Args:
        input_dim: Size of each incoming feature vector.
        embedding_dim: Width of the projection and of the encoder (d_model).
        ff_dim: Hidden width of the encoder feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
        output_dim: Number of regression targets.
    """

    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()

        # Attribute names and Sequential positions become state_dict keys, so
        # they must stay exactly as they are for saved weights to load.
        self.feat_proj = nn.Linear(input_dim, embedding_dim)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=8,
                dim_feedforward=ff_dim,
                dropout=0.0,  # dropout disabled
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        head_layers = [
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        ]
        self.regression_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return predictions for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``; with
                ``batch_first=True`` a 3-D input is read as
                ``(batch, sequence, input_dim)``.

        Returns:
            Tensor with dimension 1 averaged away and last dimension
            ``output_dim``.
        """
        encoded = self.transformer_encoder(self.feat_proj(x))
        pooled = encoded.mean(dim=1)
        return self.regression_head(pooled)

@st.cache_resource
def load_model():
    """Create a ``TransformerRegressor`` and fill it with the saved weights.

    The network is built with its default hyper-parameters, the state dict is
    read from ``transformer_model.bin`` (mapped onto the module-level
    ``device``) and loaded into it, then the network is put into eval mode and
    moved to ``device``. ``st.cache_resource`` lets Streamlit reuse the result.

    Returns:
        TransformerRegressor: The network in eval mode.
    """
    regressor = TransformerRegressor()

    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    weights = torch.load("transformer_model.bin", map_location=device)
    regressor.load_state_dict(weights)

    regressor.eval()
    regressor.to(device)
    return regressor
# ------------------------ Descriptors ------------------------
def compute_descriptors(smiles: str):
    """Compute ten RDKit descriptors for a SMILES string.

    The values are, in order: MolWt, MolLogP, TPSA, NumRotatableBonds,
    NumHDonors, NumHAcceptors, FractionCSP3, HeavyAtomCount, RingCount, MolMR.

    Args:
        smiles: SMILES text to parse with RDKit.

    Returns:
        numpy.ndarray: 1-D ``float32`` array of length 10.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")

    # Order defines the layout of the returned feature vector.
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumRotatableBonds,
        Descriptors.NumHDonors,
        Descriptors.NumHAcceptors,
        Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount,
        Descriptors.RingCount,
        Descriptors.MolMR,
    )
    values = [fn(molecule) for fn in descriptor_fns]
    return np.array(values, dtype=np.float32)

# ------------------------ Fingerprints ------------------------
def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
    """Return the Morgan bit-vector fingerprint of a SMILES string as a row.

    Args:
        smiles: SMILES text to parse with RDKit.
        radius: Morgan radius passed to RDKit.
        n_bits: Length of the bit vector (passed as ``nBits``).

    Returns:
        numpy.ndarray: ``float32`` array reshaped to a single row
        (``reshape(1, -1)``).

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        raise ValueError("Invalid SMILES string.")

    bit_vect = AllChem.GetMorganFingerprintAsBitVect(parsed, radius, nBits=n_bits)
    as_floats = np.array(bit_vect, dtype=np.float32)
    return as_floats.reshape(1, -1)

# ------------------------ Embedding ------------------------
def get_chemberta_embedding(smiles: str):
    """Embed a SMILES string with ChemBERTa.

    The string is tokenized, run through the encoder without gradient
    tracking, and the last hidden state is averaged over dimension 1.

    Args:
        smiles: SMILES text to embed.

    Returns:
        torch.Tensor: The mean-pooled last hidden state, with a leading
        dimension of 1 for the single input string, on the CPU.
    """
    # `tokenizer` and `chemberta` are not module globals; obtain them from the
    # cached loader so this function works without hidden setup.
    tokenizer, chemberta = load_chemberta()

    inputs = tokenizer(smiles, return_tensors="pt")
    # The encoder was moved to `device` by load_chemberta(); its inputs have
    # to live on the same device.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = chemberta(**inputs)

    # Average over dimension 1 instead of taking the CLS token. Returned on
    # the CPU so callers can concatenate it with CPU feature tensors.
    return outputs.last_hidden_state.mean(dim=1).cpu()

# ------------------------ Save to DB ------------------------
def save_to_db(smiles, predictions):
    """Store one prediction record in the ``polymer_predictions`` collection.

    Args:
        smiles: The SMILES string the predictions were made for.
        predictions: Mapping of property name to numeric value; every value is
            converted with ``float()`` before being stored.

    The record also carries a ``timestamp`` taken from ``datetime.now()``
    (a naive local time). It is written with ``insert_one`` on the object
    returned by ``get_database()``.
    """
    record = {
        "smiles": smiles,
        "predictions": {name: float(value) for name, value in predictions.items()},
        "timestamp": datetime.now(),
    }
    get_database()["polymer_predictions"].insert_one(record)

# ------------------------ Streamlit App ------------------------
def show():
    """Render the polymer property prediction page.

    Shows a title and a SMILES text input. When "Predict" is pressed the
    input is validated, turned into one feature row (ChemBERTa embedding,
    10 descriptors, 1280-bit Morgan fingerprint, concatenated in that order),
    run through the regressor, rescaled with the per-property ``scalers`` and
    displayed. The results are then written to the database.

    Errors during prediction are reported with ``st.error``; a failure while
    saving is reported separately with ``st.warning`` because the predictions
    themselves are still valid and already on screen.
    """
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if not st.button("Predict"):
        return

    smiles = smiles_input.strip()

    try:
        # An empty string is rejected explicitly: RDKit parses it into an
        # empty molecule rather than returning None.
        if not smiles or Chem.MolFromSmiles(smiles) is None:
            st.error("Invalid SMILES string.")
            return

        descriptors_tensor = torch.tensor(
            compute_descriptors(smiles), dtype=torch.float32
        ).unsqueeze(0)
        fingerprint_tensor = torch.tensor(
            get_morgan_fingerprint(smiles), dtype=torch.float32
        )
        # .cpu() keeps the concatenation valid wherever the encoder ran.
        embedding = get_chemberta_embedding(smiles).cpu()

        combined_input = torch.cat(
            [embedding, descriptors_tensor, fingerprint_tensor], dim=1
        )
        # Add a length-1 dimension at position 1, the axis the regressor
        # averages over.
        combined = combined_input.unsqueeze(1)

        # The regressor is not a module global; fetch it from the cached
        # loader and feed it a tensor on the same device.
        model = load_model()
        with torch.no_grad():
            preds = model(combined.to(device))

        # Tensors on CUDA must be copied to the CPU before .numpy().
        preds_np = preds.cpu().numpy()

        # Column i of the model output belongs to the i-th scaler.
        results = {}
        for column, (name, scaler) in enumerate(scalers.items()):
            rescaled = scaler.inverse_transform(preds_np[:, [column]])
            # Plain Python floats display cleanly and serialize safely.
            results[name] = round(float(rescaled[0, 0]), 4)

    except Exception as e:
        st.error(f"Prediction failed: {e}")
        return

    st.success("Predicted Properties:")
    for key, val in results.items():
        st.markdown(f"**{key}**: {val}")

    try:
        save_to_db(smiles, results)
    except Exception as e:
        # The prediction succeeded; only persistence failed.
        st.warning(f"Predictions could not be saved: {e}")