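"""Streamlit page for polymer property prediction.

Combines a ChemBERTa embedding of an input SMILES string with ten RDKit
descriptors, feeds the joint feature vector to a small Transformer
regressor, displays the six rescaled property predictions, and logs each
prediction to MongoDB via the project-local `db` module.
"""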
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors
from datetime import datetime
from db import get_database  # This must be available in your repo

# Load ChemBERTa model + tokenizer (cached, so they are loaded once per server process)
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model

tokenizer, chemberta = load_chemberta()

# Per-target scalers fitted during training; the joblib files must sit next to this script
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
}
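# Assumption: the insertion order of this dict matches the order of the six
# model output columns; the inverse_transform step in show() relies on it.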

# Model Definition
class TransformerRegressor(nn.Module):
    # input_dim = 768 (ChemBERTa CLS embedding) + 10 (RDKit descriptors) = 778,
    # matching the feature vector assembled in show(); it must also match the
    # layout used when transformer_model.pt was trained.
    def __init__(self, input_dim=778, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.feat_proj(x)           # (batch, seq, input_dim) -> (batch, seq, embedding_dim)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)               # mean-pool over the sequence axis (length 1 here)
        return self.regression_head(x)  # (batch, output_dim)
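
# Shape sanity check (illustrative, not part of the app):
#   model = TransformerRegressor()
#   model(torch.zeros(1, 1, 778)).shape  # -> torch.Size([1, 6])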

# Load model
@st.cache_resource
def load_model():
    model = TransformerRegressor()
    model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
    model.eval()
    return model

model = load_model()

# Descriptor computation
def compute_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol)
    ]
    return np.array(descriptors, dtype=np.float32)
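
# Example (illustrative): compute_descriptors("C=Cc1ccccc1")  # styrene
# returns a 10-element float32 vector (MolWt ~104.15, MolLogP, TPSA, ...).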

# Embedding function
def get_chemberta_embedding(smiles: str):
    # Truncate to the tokenizer's model maximum so very long SMILES don't error out.
    inputs = tokenizer(smiles, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # (1, 768) CLS-token embedding

# Save prediction to MongoDB
def save_to_db(smiles, predictions):
    # Cast to plain Python floats so the values are BSON-encodable.
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now()
    }
    db = get_database()
    db["polymer_predictions"].insert_one(doc)

# Main Streamlit UI + prediction
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        try:
            mol = Chem.MolFromSmiles(smiles_input)
            if mol is None:
                st.error("Invalid SMILES string.")
                return

            descriptors = compute_descriptors(smiles_input)
            descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)  # (1, 10)

            embedding = get_chemberta_embedding(smiles_input)  # (1, 768)

            # Concatenate into one feature vector, then add a length-1 sequence axis.
            combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # (1, 1, 778)

            with torch.no_grad():
                preds = model(combined)

            preds_np = preds.numpy()

            # Undo each target's scaling; output column i corresponds to keys[i].
            keys = list(scalers.keys())
            preds_rescaled = np.concatenate([
                scalers[keys[i]].inverse_transform(preds_np[:, [i]])
                for i in range(len(keys))
            ], axis=1)

            results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}

            # Display results
            st.success("Predicted Properties:")
            for key, val in results.items():
                st.markdown(f"**{key}**: {val}")

            # Save to MongoDB
            save_to_db(smiles_input, results)

        except Exception as e:
            st.error(f"Prediction failed: {e}")