import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from datetime import datetime
from db import get_database
import random

# Set seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
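# Note: cudnn.deterministic trades some GPU throughput for repeatability; these
# flags alone do not guarantee bit-identical results across different hardware.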

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ChemBERTa
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device).eval()
    return tokenizer, model
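
# st.cache_resource keeps a single tokenizer/model pair alive per process, so
# the checkpoint is fetched from the Hugging Face hub and loaded only once.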

# Load scalers
scalers = {
    "Tensile Strength(Mpa)": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy(eV)": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity(eV)": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight(g/mol)": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
}
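
# Each scaler was fit on a single target column during training, so
# inverse_transform expects a 2-D (n_samples, 1) array; predictions are
# therefore sliced column-by-column before rescaling in show() below.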

# Transformer model
class TransformerRegressor(nn.Module):
    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=8, dim_feedforward=ff_dim,
            dropout=0.1, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)
        stacked = torch.stack([x, feat_emb], dim=1)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)
        return self.regression_head(aggregated)
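
# Shape sketch (illustrative): x is the (batch, 768) ChemBERTa embedding and
# feat the (batch, 2058) concatenation of 10 RDKit descriptors and 2048 Morgan
# bits. feat_proj lifts feat to (batch, 768); stacking yields a two-"token"
# sequence (batch, 2, 768) that the encoder mixes before mean-pooling.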

# Load model
@st.cache_resource
def load_model():
    model = TransformerRegressor()
    try:
        state_dict = torch.load("transformer_model.bin", map_location=device)
        model.load_state_dict(state_dict)
        model.eval().to(device)
    except Exception as e:
        raise ValueError(f"Failed to load model: {e}")
    return model

# RDKit descriptors
def compute_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    desc = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ]
    return np.array(desc, dtype=np.float32)
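
# Usage sketch (hypothetical input): compute_descriptors("CCO") returns a
# (10,) float32 vector for ethanol (MolWt ~46.07, one H-bond donor, zero
# rings); an unparsable SMILES raises ValueError rather than returning NaNs.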

# Morgan fingerprint
def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return np.array(fp, dtype=np.float32).reshape(1, -1)
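
# Bit i of the fingerprint is 1 if some atom environment of radius <= 2 hashes
# to i; radius=2 with 2048 bits plus the 10 descriptors above gives the
# 2058-dim feat input TransformerRegressor expects. Newer RDKit releases steer
# toward rdFingerprintGenerator, but GetMorganFingerprintAsBitVect still works.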

# ChemBERTa embedding
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    # Move the tokenized inputs onto the same device as the model; without this,
    # a CUDA model would receive CPU tensors and the forward pass would fail.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = chemberta(**inputs)
    # Mean-pool the token embeddings into a single (1, 768) vector.
    return outputs.last_hidden_state.mean(dim=1)

# Save to DB
def save_to_db(smiles, predictions):
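    # Cast to plain Python floats first: BSON cannot encode numpy scalar types.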
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now()
    }
    db = get_database()
    db["polymer_predictions"].insert_one(doc)

# Streamlit UI
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        try:
            model = load_model()
            tokenizer, chemberta = load_chemberta()

            mol = Chem.MolFromSmiles(smiles_input)
            if mol is None:
                st.error("Invalid SMILES string.")
                return

            descriptors = compute_descriptors(smiles_input)
            descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
            fingerprint = get_morgan_fingerprint(smiles_input)
            fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32)
            features = torch.cat([descriptors_tensor, fingerprint_tensor], dim=1).to(device)
            embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)


            with torch.no_grad():
                preds = model(embedding,features)

            preds_np = preds.cpu().numpy()
            keys = list(scalers.keys())
            preds_rescaled = np.concatenate([
                scalers[keys[i]].inverse_transform(preds_np[:, [i]])
                for i in range(len(keys))
            ], axis=1)

            results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}

            st.success("Predicted Properties:")
            for key, val in results.items():
                st.markdown(f"**{key}**: {val}")

            save_to_db(smiles_input, results)

        except Exception as e:
            st.error(f"Prediction failed: {e}")