File size: 7,033 Bytes
4b283df
84dad8f
5e9e549
b73e3e2
4b283df
 
84dad8f
5834d19
 
4b283df
5834d19
cf36af6
c621eb3
6c37c0f
f8c8eb7
cf36af6
 
 
 
 
 
5834d19
 
b69e05d
77a4bbb
cf36af6
a08c2a4
 
 
 
 
5834d19
a08c2a4
b69e05d
5834d19
 
 
 
 
 
 
 
 
 
 
5e9e549
4b283df
5e9e549
3de6f45
 
5834d19
 
 
 
3de6f45
 
eea9e94
 
3de6f45
5e9e549
3de6f45
eea9e94
3de6f45
5e9e549
 
eea9e94
 
 
3de6f45
eea9e94
 
5834d19
4b283df
 
5834d19
ca63203
5834d19
 
a71233a
23e592f
a71233a
 
 
 
 
5834d19
4b283df
77a4bbb
cf36af6
4b283df
 
 
 
5834d19
 
4b283df
 
 
 
 
 
 
 
 
 
5834d19
 
4b283df
cf36af6
af4f3b0
 
 
 
 
cf36af6
af4f3b0
5834d19
 
 
 
 
 
4b283df
cf36af6
4b283df
5834d19
4b283df
 
5834d19
4b283df
 
 
 
 
5834d19
4b283df
 
 
 
 
 
 
 
a71233a
 
 
4b283df
 
 
 
 
5834d19
 
 
 
4b283df
cf36af6
af4f3b0
cf36af6
 
4b283df
5834d19
4b283df
5834d19
cf36af6
5834d19
4b283df
 
5834d19
4b283df
5834d19
4b283df
5834d19
 
4b283df
5834d19
 
4b283df
 
 
 
 
 
 
 
5834d19
4b283df
 
 
77a4bbb
 
c615f5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from datetime import datetime
from db import get_database  # Ensure this module is available
import random

 

# ------------------------ Ensuring Deterministic Behavior ------------------------
# A single seed shared by every random number generator this module touches,
# so repeated runs start from the same RNG state.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Request deterministic cuDNN kernels and disable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
@st.cache_resource
def load_chemberta():
    """Load the ChemBERTa tokenizer and encoder.

    Decorated with ``st.cache_resource`` so Streamlit caches the returned
    objects instead of rebuilding them on each call.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is switched to eval mode
        and moved to the module-level ``device``.
    """
    checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # eval() and to() both return the module itself, so the calls can be chained.
    encoder = AutoModel.from_pretrained(checkpoint).eval().to(device)
    return tokenizer, encoder

# ------------------------ Load Scalers ------------------------
# (property name, scaler file) pairs. The order is significant: show() maps
# the i-th model output column to the i-th key of `scalers`.
_SCALER_FILES = (
    ("Tensile Strength", "scaler_Tensile_strength_Mpa_.joblib"),
    ("Ionization Energy", "scaler_Ionization_Energy_eV_.joblib"),
    ("Electron Affinity", "scaler_Electron_Affinity_eV_.joblib"),
    ("logP", "scaler_LogP.joblib"),
    ("Refractive Index", "scaler_Refractive_Index.joblib"),
    ("Molecular Weight", "scaler_Molecular_Weight_g_mol_.joblib"),
)

# Fitted scalers, loaded once at import time and keyed by property name.
scalers = {
    property_name: joblib.load(file_name)
    for property_name, file_name in _SCALER_FILES
}

# ------------------------ Transformer Model ------------------------
class TransformerRegressor(nn.Module):
    """Transformer-encoder regressor over concatenated molecular features.

    The input is linearly projected to ``embedding_dim``, passed through a
    stack of Transformer encoder layers (8 heads, no dropout, batch-first),
    mean-pooled over dimension 1 and mapped to ``output_dim`` values by a
    three-layer MLP.

    Args:
        input_dim: Size of the last dimension of the input features.
        embedding_dim: Model width used by the projection and the encoder.
        ff_dim: Width of the feed-forward sublayer in each encoder layer.
        num_layers: Number of stacked encoder layers.
        output_dim: Number of regression targets.
    """

    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()

        # Attribute names and construction order are part of the checkpoint
        # format (state_dict keys), so they must stay exactly as they are.
        self.feat_proj = nn.Linear(input_dim, embedding_dim)

        layer_template = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.0,  # No dropout for consistency
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        head_layers = [
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        ]
        self.regression_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map features of shape (batch, seq, input_dim) to (batch, output_dim)."""
        projected = self.feat_proj(x)
        encoded = self.transformer_encoder(projected)
        # Collapse the sequence axis (dim 1, since batch_first=True) by averaging.
        pooled = encoded.mean(dim=1)
        return self.regression_head(pooled)

# ------------------------ Load Model ------------------------
@st.cache_resource
def load_model():
    """Build the regressor and load its trained weights.

    Decorated with ``st.cache_resource`` so Streamlit caches the returned
    model instead of reloading the checkpoint on each call.

    Returns:
        A ``TransformerRegressor`` in eval mode on the module-level ``device``.

    Raises:
        ValueError: If the checkpoint cannot be read or applied to the model.
            The underlying exception is attached as ``__cause__``.
    """
    # Initialize the model architecture first
    model = TransformerRegressor()

    # Load the state_dict (weights) from the saved model file
    try:
        # NOTE: torch.load unpickles the file, so the checkpoint must come
        # from a trusted source.
        state_dict = torch.load("transformer_model.bin", map_location=device)  # Ensure loading on the correct device
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)  # Send model to GPU if available
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Failed to load model: {e}") from e

    return model

# ------------------------ Descriptors ------------------------
def compute_descriptors(smiles: str):
    """Compute ten RDKit descriptors for a molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A float32 NumPy array of length 10 holding, in order: molecular
        weight, logP, TPSA, rotatable bonds, H-bond donors, H-bond acceptors,
        fraction of sp3 carbons, heavy-atom count, ring count and molar
        refractivity.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")

    # Keep this order: the vector is concatenated into the model input, which
    # presumably has to match the feature layout used at training time.
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumRotatableBonds,
        Descriptors.NumHDonors,
        Descriptors.NumHAcceptors,
        Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount,
        Descriptors.RingCount,
        Descriptors.MolMR,
    )
    values = [descriptor_fn(molecule) for descriptor_fn in descriptor_fns]
    return np.array(values, dtype=np.float32)

# ------------------------ Fingerprints ------------------------
def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
    """Compute a Morgan bit-vector fingerprint for a molecule.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan fingerprint radius.
        n_bits: Length of the bit vector.

    Returns:
        A float32 NumPy array of shape ``(1, n_bits)``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")

    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
    bits = np.array(bit_vector, dtype=np.float32)
    # Add a leading batch axis so the result can be concatenated row-wise.
    return bits.reshape(1, -1)

# ------------------------ Embedding ------------------------
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    """Return a mean-pooled ChemBERTa embedding for one SMILES string.

    Args:
        smiles: SMILES string to embed.
        tokenizer: Callable that tokenizes ``smiles`` into a mapping of
            PyTorch tensors (called with ``return_tensors="pt"``).
        chemberta: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        The token-averaged ``last_hidden_state``: a tensor with the sequence
        axis (dim 1) reduced away. It lives on the encoder's device and
        carries no gradient.
    """
    inputs = tokenizer(smiles, return_tensors="pt")

    # Tokenizer output is created on the CPU by default, while the encoder may
    # have been moved to the GPU. Move the inputs to the encoder's device so
    # the forward pass does not fail with a device mismatch.
    model_device = getattr(chemberta, "device", None)
    if model_device is not None:
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state.mean(dim=1)  # Use average instead of CLS token

# ------------------------ Save to DB ------------------------
def save_to_db(smiles, predictions):
    """Persist one prediction record to the ``polymer_predictions`` collection.

    Args:
        smiles: SMILES string the predictions were made for.
        predictions: Mapping of property name to numeric value. Values are
            converted to plain ``float`` before storage.
    """
    # Local import keeps this block self-contained; the module imports only
    # `datetime` from the datetime package.
    from datetime import timezone

    record = {
        "smiles": smiles,
        "predictions": {name: float(value) for name, value in predictions.items()},
        # Use an aware UTC timestamp. A naive local time is ambiguous, and the
        # insert_one API suggests PyMongo, which treats naive datetimes as UTC
        # and would therefore store them shifted by the local offset.
        "timestamp": datetime.now(timezone.utc),
    }
    get_database()["polymer_predictions"].insert_one(record)

# ------------------------ Streamlit App ------------------------
def _predict_properties(smiles):
    """Run the full feature pipeline and model for one SMILES string.

    Returns a dict mapping each key of ``scalers`` to the rescaled prediction,
    as a plain ``float`` rounded to 4 decimals.
    """
    model = load_model()
    tokenizer, chemberta = load_chemberta()

    descriptors_tensor = torch.tensor(compute_descriptors(smiles), dtype=torch.float32).unsqueeze(0)
    fingerprint_tensor = torch.tensor(get_morgan_fingerprint(smiles), dtype=torch.float32)
    embedding = get_chemberta_embedding(smiles, tokenizer, chemberta)

    # Put every feature on the model's device before concatenating, otherwise
    # the forward pass fails with a device mismatch when CUDA is in use.
    features = torch.cat(
        [embedding.to(device), descriptors_tensor.to(device), fingerprint_tensor.to(device)],
        dim=1,
    )

    with torch.no_grad():
        # Add a sequence axis of length 1: (batch, features) -> (batch, 1, features).
        preds = model(features.unsqueeze(1))

    # .numpy() only works on CPU tensors, so move the result back first.
    preds_np = preds.cpu().numpy()

    # Output column i belongs to the i-th scaler, in insertion order.
    results = {}
    for column, (name, scaler) in enumerate(scalers.items()):
        rescaled = scaler.inverse_transform(preds_np[:, [column]])
        # Convert to a plain float so display and storage do not carry NumPy types.
        results[name] = round(float(rescaled[0, 0]), 4)
    return results


def show():
    """Render the polymer property prediction page.

    Reads a SMILES string from a text input. When "Predict" is pressed it
    validates the input, runs the model, displays the rescaled predictions
    and stores them through ``save_to_db``. Failures are reported in the UI
    rather than raised.
    """
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if not st.button("Predict"):
        return

    # Validate before loading any model. An empty string needs an explicit
    # check because RDKit can parse it into an atom-less molecule, not None.
    smiles = smiles_input.strip()
    if not smiles or Chem.MolFromSmiles(smiles) is None:
        st.error("Invalid SMILES string.")
        return

    try:
        results = _predict_properties(smiles)
    except Exception as e:
        st.error(f"Prediction failed: {e}")
        return

    st.success("Predicted Properties:")
    for key, val in results.items():
        st.markdown(f"**{key}**: {val}")

    # Saving is best-effort: the predictions are already on screen, so a
    # database failure must not be reported as a failed prediction.
    try:
        save_to_db(smiles, results)
    except Exception as e:
        st.warning(f"Predictions were computed but could not be saved: {e}")