File size: 17,727 Bytes
4b283df
84dad8f
5e9e549
b73e3e2
4b283df
 
84dad8f
5909561
4b283df
8e77cfc
cf36af6
5909561
 
 
 
f8c8eb7
76beebb
cf36af6
 
 
 
 
5834d19
b69e05d
77a4bbb
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76beebb
a08c2a4
 
5909561
 
 
a08c2a4
b69e05d
76beebb
5909561
 
 
 
 
 
 
 
 
 
5834d19
76beebb
5e9e549
a715cd6
5e9e549
a715cd6
3de6f45
76beebb
23348f0
3de6f45
eea9e94
 
3de6f45
5e9e549
3de6f45
eea9e94
3de6f45
5e9e549
 
5909561
 
 
 
 
23348f0
eea9e94
76beebb
4b283df
 
5909561
 
 
 
 
 
 
 
4b283df
77a4bbb
76beebb
4b283df
 
 
 
76beebb
4b283df
 
 
 
 
 
 
 
 
 
5834d19
76beebb
4b283df
76beebb
23348f0
af4f3b0
 
 
 
5909561
af4f3b0
76beebb
a715cd6
76beebb
5909561
5834d19
a715cd6
5909561
4b283df
76beebb
5909561
5834d19
4b283df
 
5834d19
4b283df
 
5909561
 
 
4b283df
 
5909561
4b283df
5909561
 
 
 
 
 
 
 
 
 
8a98757
8ae7b7c
5909561
 
 
 
 
4b283df
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b283df
8a98757
8ae7b7c
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3846070
 
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8e3d22
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b283df
5909561
 
 
 
 
 
a71233a
76beebb
5909561
3846070
5909561
 
 
 
23348f0
5909561
23348f0
5909561
 
3846070
5909561
 
 
 
5834d19
3846070
5909561
 
 
 
4b283df
5909561
 
 
4b283df
 
23348f0
5834d19
4b283df
5909561
 
3846070
5909561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b283df
 
5909561
 
 
 
 
 
 
 
 
77a4bbb
5909561
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, Draw
from datetime import datetime
from db import get_database
import random
import pandas as pd
import time
import base64
from io import BytesIO

# Set seeds
# Use one seed for every RNG in play (Python, NumPy, PyTorch), aiming for
# reproducible runs. cuDNN is pinned to deterministic kernels, and its
# autotuner is switched off for the same reason.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Page styling and configuration
st.set_page_config(
    page_title="Polymer Property Prediction",
    # Test-tube emoji. It was previously stored mis-encoded (its UTF-8 bytes
    # decoded with a legacy code page), which is not a valid emoji icon.
    page_icon="🧪",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS
# Page-wide style sheet injected once per script run. Several of these classes
# (main-header, sub-header, property-card, info-box) are used by the HTML
# snippets the page emits. The .stProgress rule recolours Streamlit's
# progress bars.
CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: 700;
        color: #4CAF50;
        text-align: center;
        margin-bottom: 1rem;
        background: linear-gradient(90deg, #f8f9fa 0%, #e9ecef 100%);
        padding: 1.5rem 0;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .sub-header {
        font-size: 1.5rem;
        font-weight: 600;
        color: #2E7D32;
        margin-bottom: 0.5rem;
    }
    .property-card {
        background-color: #f1f8e9;
        border-radius: 10px;
        padding: 1rem;
        margin: 0.5rem 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
        transition: transform 0.3s ease;
    }
    .property-card:hover {
        transform: translateY(-5px);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .loader {
        border: 16px solid #f3f3f3;
        border-radius: 50%;
        border-top: 16px solid #3498db;
        width: 50px;
        height: 50px;
        animation: spin 2s linear infinite;
        margin: 20px auto;
    }
    .info-box {
        background-color: #e3f2fd;
        border-left: 5px solid #2196f3;
        padding: 1rem;
        margin: 1rem 0;
        border-radius: 5px;
    }
    .tooltip {
        position: relative;
        display: inline-block;
        border-bottom: 1px dotted black;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 120px;
        background-color: black;
        color: #fff;
        text-align: center;
        border-radius: 6px;
        padding: 5px 0;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        margin-left: -60px;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50 !important;
    }
</style>
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)

# Load ChemBERTa
@st.cache_resource
def load_chemberta():
    """Load the ChemBERTa tokenizer and encoder once per server process.

    Returns:
        (tokenizer, model): the encoder is moved to ``device`` and put in
        eval mode.
    """
    checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
    with st.spinner("Loading ChemBERTa model..."):
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        encoder = AutoModel.from_pretrained(checkpoint)
        encoder = encoder.to(device).eval()
    return tokenizer, encoder

# Load scalers
@st.cache_resource
def load_scalers():
    """Load one fitted target scaler per predicted property.

    Returns:
        dict mapping display name -> scaler loaded with ``joblib``. Insertion
        order matters: prediction code pairs model output column i with the
        i-th key.
    """
    scaler_files = (
        ("Tensile Strength (MPa)", "scaler_Tensile_strength_Mpa_.joblib"),
        ("Ionization Energy (eV)", "scaler_Ionization_Energy_eV_.joblib"),
        ("Electron Affinity (eV)", "scaler_Electron_Affinity_eV_.joblib"),
        ("logP", "scaler_LogP.joblib"),
        ("Refractive Index", "scaler_Refractive_Index.joblib"),
        ("Molecular Weight (g/mol)", "scaler_Molecular_Weight_g_mol_.joblib"),
    )
    return {name: joblib.load(path) for name, path in scaler_files}

# Transformer model
class TransformerRegressor(nn.Module):
    """Regress polymer properties from a text embedding plus handcrafted features.

    Each sample becomes a two-token sequence:
    - the embedding ``x``;
    - the handcrafted feature vector ``feat``, projected to the embedding width.

    The sequence goes through a Transformer encoder and is mean-pooled over the
    two tokens. An MLP head then maps the result to ``output_dim`` values.

    NOTE: the submodule names and the layer order inside ``regression_head``
    determine the ``state_dict`` keys a saved checkpoint is loaded into. Do
    not rename or reorder them.
    """

    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        # Bring the handcrafted feature vector to the embedding width so it can
        # sit next to x as a second token.
        self.feat_proj = nn.Linear(feat_dim, embedding_dim)
        # TransformerEncoder clones the layer num_layers times.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=8,
                dim_feedforward=ff_dim,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        # Stack along dim=1 to get (batch, 2, embedding_dim). This assumes x is
        # (batch, embedding_dim).
        tokens = torch.stack((x, self.feat_proj(feat)), dim=1)
        pooled = self.transformer_encoder(tokens).mean(dim=1)
        return self.regression_head(pooled)

# Load model
@st.cache_resource
def load_model():
    """Build a TransformerRegressor and load its weights from ``transformer_model.bin``.

    The model is cached by ``st.cache_resource``, put in eval mode and moved
    to ``device``.

    Raises:
        ValueError: if reading the checkpoint, loading it into the
            architecture, or moving the model fails. The underlying
            exception is chained as ``__cause__``, so the real traceback is
            kept.
    """
    with st.spinner("Loading prediction model..."):
        model = TransformerRegressor()
        try:
            # NOTE(review): torch.load unpickles the file. Only load trusted
            # checkpoints; consider weights_only=True where the installed
            # torch supports it.
            state_dict = torch.load("transformer_model.bin", map_location=device)
            model.load_state_dict(state_dict)
            model.eval().to(device)
        except Exception as e:
            raise ValueError(f"Failed to load model: {e}") from e
    return model

# RDKit descriptors
def compute_descriptors(smiles: str):
    """Return 10 RDKit descriptors for *smiles* as a float32 NumPy vector.

    The order is MolWt, MolLogP, TPSA, NumRotatableBonds, NumHDonors,
    NumHAcceptors, FractionCSP3, HeavyAtomCount, RingCount, MolMR. It
    presumably must match the feature order the model was trained on, so do
    not reorder.

    Raises:
        ValueError: if *smiles* cannot be parsed.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")
    calculators = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumRotatableBonds,
        Descriptors.NumHDonors,
        Descriptors.NumHAcceptors,
        Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount,
        Descriptors.RingCount,
        Descriptors.MolMR,
    )
    return np.array([calc(molecule) for calc in calculators], dtype=np.float32)

# Morgan fingerprint
def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
    """Return the Morgan bit-vector fingerprint of *smiles* as a (1, n_bits) float32 array.

    Raises:
        ValueError: if *smiles* cannot be parsed.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")
    # NOTE(review): newer RDKit releases deprecate this call in favour of
    # rdFingerprintGenerator. Verify bit-identical output before switching,
    # since the model was trained on these bits.
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
    bits = np.array(bit_vector, dtype=np.float32)
    return bits.reshape(1, -1)

# ChemBERTa embedding
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    """Embed *smiles* by averaging ChemBERTa's last hidden layer over the token axis.

    Inputs are moved to ``device``, and the encoder runs without gradient
    tracking. For a single string the result is presumably
    (1, hidden_size) — TODO confirm against the checkpoint.
    """
    encoded = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden_states = chemberta(**model_inputs).last_hidden_state
    return hidden_states.mean(dim=1)

# Save to DB
def save_to_db(smiles, predictions, mol_image=None):
    """Insert one prediction record into the ``polymer_predictions`` collection.

    Args:
        smiles: the input SMILES string.
        predictions: mapping property -> value; values are coerced to float.
        mol_image: optional base64 PNG, stored only when truthy.

    Returns:
        The record's ``_id``. This relies on ``insert_one`` filling it into
        the dict (pymongo behaviour; presumably what ``get_database``
        returns).
    """
    record = {
        "smiles": smiles,
        "predictions": {name: float(value) for name, value in predictions.items()},
        "timestamp": datetime.now(),
    }
    if mol_image:
        record["molecule_image"] = mol_image

    collection = get_database()["polymer_predictions"]
    collection.insert_one(record)
    return record["_id"]

# Get molecule image as base64
def get_molecule_image(smiles):
    """Render *smiles* as a 300x300 PNG and return it base64-encoded (str).

    Returns None when RDKit cannot parse *smiles*.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    png_buffer = BytesIO()
    Draw.MolToImage(molecule, size=(300, 300)).save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode()

# Removed example SMILES

# Get history from database
def get_prediction_history(limit=5):
    """Return up to *limit* saved prediction records, newest timestamp first, as a list."""
    cursor = (
        get_database()["polymer_predictions"]
        .find()
        .sort("timestamp", -1)
        .limit(limit)
    )
    return list(cursor)

# Sidebar
def show_sidebar():
    """Render the sidebar: tool description, property explanations and recent predictions.

    "Recent Predictions" lists the five newest records returned by
    ``get_prediction_history``. Each is an expander showing the full SMILES
    and every stored property value.
    """
    st.sidebar.markdown("<div class='sub-header'>About This Tool</div>", unsafe_allow_html=True)
    st.sidebar.info("""
    This tool predicts key properties of polymers based on their SMILES representation.
    
    It uses a transformer neural network combined with ChemBERTa embeddings and molecular descriptors.
    """)
    
    st.sidebar.markdown("<div class='sub-header'>Property Explanations</div>", unsafe_allow_html=True)
    
    with st.sidebar.expander("Tensile Strength"):
        st.write("""
        **Tensile Strength (MPa)** measures the maximum stress a material can withstand before breaking.
        Higher values indicate stronger materials.
        """)
    
    with st.sidebar.expander("Ionization Energy"):
        st.write("""
        **Ionization Energy (eV)** is the energy required to remove an electron from an atom or molecule.
        It affects chemical reactivity and stability.
        """)
        
    with st.sidebar.expander("Electron Affinity"):
        st.write("""
        **Electron Affinity (eV)** measures how much energy is released when an electron is added to a neutral atom.
        It influences a polymer's electrical properties.
        """)
        
    with st.sidebar.expander("logP"):
        st.write("""
        **logP** is the partition coefficient that measures how a substance distributes between water and lipid phases.
        It affects solubility and permeability of polymers.
        """)
        
    with st.sidebar.expander("Refractive Index"):
        st.write("""
        **Refractive Index** measures how light propagates through the material.
        It's important for optical applications of polymers.
        """)
        
    with st.sidebar.expander("Molecular Weight"):
        st.write("""
        **Molecular Weight (g/mol)** is the mass of a molecule.
        It affects mechanical properties, processability, and many other characteristics.
        """)
    
    st.sidebar.markdown("<div class='sub-header'>Recent Predictions</div>", unsafe_allow_html=True)
    history = get_prediction_history(5)
    if history:
        for i, item in enumerate(history):
            smiles = item["smiles"]
            # Add the ellipsis only when characters were actually cut off, so a
            # short SMILES such as "CCO" is not shown as if truncated.
            label_smiles = smiles if len(smiles) <= 15 else f"{smiles[:15]}..."
            timestamp = item["timestamp"].strftime("%Y-%m-%d %H:%M")
            with st.sidebar.expander(f"#{i+1}: {label_smiles} ({timestamp})"):
                st.code(smiles, language="text")
                for prop, val in item["predictions"].items():
                    st.write(f"**{prop}**: {val:.4f}")
    else:
        st.sidebar.write("No prediction history available.")

# Example SMILES section removed

# Property visualization
def visualize_properties(results):
    """Show each predicted property as a card with a range bar, plus a comparison bar chart.

    Args:
        results: mapping of property name -> predicted value (real units).

    Bars use fixed per-property display ranges. Values are min-max scaled
    into [0, 1] and clamped. The ranges affect only this display, never the
    predictions. Properties without a listed range use (0, 1).
    """
    st.markdown("<div class='sub-header'>Property Visualization</div>", unsafe_allow_html=True)

    # Display ranges used to place each value on a 0-1 bar
    property_ranges = {
        "Tensile Strength (MPa)": (0, 200),
        "Ionization Energy (eV)": (5, 15),
        "Electron Affinity (eV)": (0, 5),
        "logP": (-5, 10),
        "Refractive Index": (1, 2),
        "Molecular Weight (g/mol)": (0, 5000)
    }

    normalized_values = {}
    for prop, value in results.items():
        min_val, max_val = property_ranges.get(prop, (0, 1))
        normalized = (value - min_val) / (max_val - min_val)
        # Clamp to [0, 1]. This particular min/max order also turns a NaN into
        # 0, keeping NaN out of st.progress.
        normalized_values[prop] = max(0, min(normalized, 1))

    # One card per property, three per row
    cols = st.columns(3)
    for i, (prop, norm_val) in enumerate(normalized_values.items()):
        with cols[i % 3]:
            # Emit the card in a single markdown call. Each st.markdown call is
            # its own element, so a <div> opened in one call cannot wrap
            # content from later calls.
            st.markdown(
                f"<div class='property-card'><h4>{prop}</h4>"
                f"<h3 style='text-align: center;'>{results[prop]:.4f}</h3></div>",
                unsafe_allow_html=True,
            )
            st.progress(float(norm_val))

    # Bar chart comparing the normalized values side by side
    normalized_df = pd.DataFrame({
        'Property': list(normalized_values.keys()),
        'Normalized Value': list(normalized_values.values()),
    })
    st.bar_chart(normalized_df.set_index('Property')['Normalized Value'])

# Main function
def _clear_smiles_input():
    """Button callback that empties the SMILES text box.

    ``on_click`` callbacks run before the script re-executes, so the keyed
    ``text_input`` in ``show`` is rebuilt with the cleared value. Assigning
    the key after the widget exists in the same run is not allowed.
    """
    st.session_state.smiles_input = ""


def _render_molecule_preview(mol, mol_img):
    """Show the molecule depiction beside a few quick RDKit descriptors.

    Args:
        mol: parsed RDKit molecule.
        mol_img: base64 PNG from ``get_molecule_image``, or None.
    """
    col1, col2 = st.columns([1, 2])
    with col1:
        if mol_img:
            st.markdown(f"<img src='data:image/png;base64,{mol_img}' style='max-width:100%;'>", unsafe_allow_html=True)
    with col2:
        # Emit the whole box in one markdown call. A <div> opened in one
        # st.markdown call cannot wrap elements emitted by later calls.
        st.markdown(
            "<div class='info-box'>"
            "<h3>Molecule Properties</h3>"
            f"<p><b>Formula:</b> {Chem.rdMolDescriptors.CalcMolFormula(mol)}</p>"
            f"<p><b>Rings:</b> {Descriptors.RingCount(mol)}</p>"
            f"<p><b>H-Bond Donors:</b> {Descriptors.NumHDonors(mol)}</p>"
            f"<p><b>H-Bond Acceptors:</b> {Descriptors.NumHAcceptors(mol)}</p>"
            "</div>",
            unsafe_allow_html=True,
        )


def _predict_properties(smiles, progress_bar, status_text):
    """Featurize *smiles*, run the regressor and return ``{property: value}`` in real units.

    Progress is reported through *progress_bar* (an ``st.progress``) and
    *status_text* (an ``st.empty`` placeholder, cleared on success).
    """
    # Step 1: Load models (each loader is cached with st.cache_resource)
    status_text.text("Loading models...")
    model = load_model()
    tokenizer, chemberta = load_chemberta()
    scalers = load_scalers()
    progress_bar.progress(0.25)
    time.sleep(0.5)  # Deliberate pause so each step is visible to the user

    # Step 2: Handcrafted features: 10 RDKit descriptors followed by the
    # 2048-bit Morgan fingerprint, matching the regressor's default feat_dim (2058).
    status_text.text("Computing molecular features...")
    descriptors_tensor = torch.tensor(compute_descriptors(smiles), dtype=torch.float32).unsqueeze(0)
    fingerprint_tensor = torch.tensor(get_morgan_fingerprint(smiles), dtype=torch.float32)
    features = torch.cat([descriptors_tensor, fingerprint_tensor], dim=1).to(device)
    progress_bar.progress(0.50)
    time.sleep(0.5)

    # Step 3: Generate embeddings
    status_text.text("Generating ChemBERTa embeddings...")
    embedding = get_chemberta_embedding(smiles, tokenizer, chemberta)
    progress_bar.progress(0.75)
    time.sleep(0.5)

    # Step 4: Predict in scaled space, then undo each property's scaling
    status_text.text("Making predictions...")
    with torch.no_grad():
        preds_np = model(embedding, features).cpu().numpy()
    # Output column i is paired with the i-th scaler key (dict insertion order).
    # This assumes the model was trained with its targets in that same order.
    keys = list(scalers.keys())
    preds_rescaled = np.concatenate(
        [scalers[key].inverse_transform(preds_np[:, [i]]) for i, key in enumerate(keys)],
        axis=1,
    )
    progress_bar.progress(1.0)
    status_text.empty()
    return dict(zip(keys, preds_rescaled.flatten()))


def _render_detailed_results(results):
    """Show predictions as a table inside an expander, with a CSV download button."""
    with st.expander("View Detailed Results"):
        result_df = pd.DataFrame({
            'Property': list(results.keys()),
            'Predicted Value': [f"{val:.4f}" for val in results.values()]
        })
        st.table(result_df)

        # Export options
        csv = result_df.to_csv(index=False)
        st.download_button(
            label="Download Results as CSV",
            data=csv,
            file_name=f"polymer_prediction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv"
        )


def show():
    """Render the polymer property prediction page.

    The flow is:
    1. SMILES input.
    2. Live validation and molecule preview.
    3. On "Predict": featurize, run the model, save the record to the
       database and display the results.
    """
    # The SMILES box is bound to this session-state key. Seed it before the
    # widget is created, and change it afterwards only from callbacks.
    if 'smiles_input' not in st.session_state:
        st.session_state.smiles_input = ""

    # Main header
    st.markdown("<div class='main-header'>🧪 Polymer Property Prediction</div>", unsafe_allow_html=True)

    # Sidebar
    show_sidebar()

    # Input section
    st.markdown("<div class='sub-header'>Input Your Polymer</div>", unsafe_allow_html=True)

    col1, col2 = st.columns([3, 1])
    with col1:
        smiles_input = st.text_input("Enter SMILES Representation",
                                     key="smiles_input",
                                     help="SMILES (Simplified Molecular Input Line Entry System) is a notation representing molecular structure.")
    with col2:
        st.markdown("<br>", unsafe_allow_html=True)
        # Clearing happens in a callback because the text box above already
        # exists in this run.
        st.button("Clear", key="clear_button", on_click=_clear_smiles_input)

    # Input validation
    is_valid = False
    mol_img = None
    if smiles_input:
        mol = Chem.MolFromSmiles(smiles_input)
        is_valid = mol is not None

        if is_valid:
            mol_img = get_molecule_image(smiles_input)
            _render_molecule_preview(mol, mol_img)
        else:
            st.warning("Invalid SMILES string. Please check your input.")

    # Prediction button
    run_prediction = st.button("🔍 Predict Properties", disabled=not is_valid, key="predict_button")

    if run_prediction:
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()
            results = _predict_properties(smiles_input, progress_bar, status_text)

            # The prediction already succeeded. A database problem should only
            # cost the history entry, not hide the results from the user.
            try:
                save_to_db(smiles_input, results, mol_img)
            except Exception as db_error:
                st.warning(f"Could not save prediction to history: {db_error}")

            st.success("✅ Prediction completed successfully!")
            visualize_properties(results)
            _render_detailed_results(results)

        except Exception as e:
            st.error(f"Prediction failed: {str(e)}")
            st.code(str(e))

    # Footer
    st.markdown("""
    <div style="text-align: center; margin-top: 3rem; padding-top: 1rem; border-top: 1px solid #ccc; color: #666;">
        <p>Polymer Property Prediction Tool - © 2025</p>
    </div>
    """, unsafe_allow_html=True)


# Render the page when this file is executed as the main script
# (presumably via `streamlit run`).
if __name__ == "__main__":
    show()