transpolymer committed on
Commit
5909561
·
verified ·
1 Parent(s): ddd8517

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +396 -56
prediction.py CHANGED
@@ -5,10 +5,14 @@ import numpy as np
5
  import joblib
6
  from transformers import AutoTokenizer, AutoModel
7
  from rdkit import Chem
8
- from rdkit.Chem import Descriptors, AllChem
9
  from datetime import datetime
10
  from db import get_database
11
  import random
 
 
 
 
12
 
13
  # Set seeds
14
  random.seed(42)
@@ -19,22 +23,116 @@ torch.backends.cudnn.benchmark = False
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Load ChemBERTa
23
  @st.cache_resource
24
  def load_chemberta():
25
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
26
- model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device).eval()
 
27
  return tokenizer, model
28
 
29
  # Load scalers
30
- scalers = {
31
- "Tensile Strength(Mpa)": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
32
- "Ionization Energy(eV)": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
33
- "Electron Affinity(eV)": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
34
- "logP": joblib.load("scaler_LogP.joblib"),
35
- "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
36
- "Molecular Weight(g/mol)": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
37
- }
 
 
38
 
39
  # Transformer model
40
  class TransformerRegressor(nn.Module):
@@ -54,23 +152,24 @@ class TransformerRegressor(nn.Module):
54
  nn.Linear(128, output_dim)
55
  )
56
 
57
- def forward(self, x,feat):
58
- feat_emb=self.feat_proj(feat)
59
- stacked=torch.stack([x,feat_emb],dim=1)
60
- encoded=self.transformer_encoder(stacked)
61
- aggregated=encoded.mean(dim=1)
62
  return self.regression_head(aggregated)
63
 
64
  # Load model
65
  @st.cache_resource
66
  def load_model():
67
- model = TransformerRegressor()
68
- try:
69
- state_dict = torch.load("transformer_model.bin", map_location=device)
70
- model.load_state_dict(state_dict)
71
- model.eval().to(device)
72
- except Exception as e:
73
- raise ValueError(f"Failed to load model: {e}")
 
74
  return model
75
 
76
  # RDKit descriptors
@@ -98,70 +197,311 @@ def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
98
  if mol is None:
99
  raise ValueError("Invalid SMILES string.")
100
  fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
101
- return np.array(fp, dtype=np.float32).reshape(1,-1)
102
 
103
  # ChemBERTa embedding
104
  def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
105
  inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
 
106
  with torch.no_grad():
107
  outputs = chemberta(**inputs)
108
- return outputs.last_hidden_state.mean(dim=1).to(device)
109
 
110
  # Save to DB
111
- def save_to_db(smiles, predictions):
112
  predictions_clean = {k: float(v) for k, v in predictions.items()}
113
  doc = {
114
  "smiles": smiles,
115
  "predictions": predictions_clean,
116
  "timestamp": datetime.now()
117
  }
 
 
 
118
  db = get_database()
119
  db["polymer_predictions"].insert_one(doc)
 
120
 
121
- # Streamlit UI
122
- def show():
123
- st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
124
- st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- smiles_input = st.text_input("Enter SMILES Representation of Polymer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- if st.button("Predict"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  try:
 
 
 
 
 
 
130
  model = load_model()
131
  tokenizer, chemberta = load_chemberta()
132
-
133
- mol = Chem.MolFromSmiles(smiles_input)
134
- if mol is None:
135
- st.error("Invalid SMILES string.")
136
- return
137
-
138
  descriptors = compute_descriptors(smiles_input)
139
- descriptors_tensor=torch.tensor(descriptors,dtype=torch.float32).unsqueeze(0)
140
  fingerprint = get_morgan_fingerprint(smiles_input)
141
- fingerprint_tensor=torch.tensor(fingerprint,dtype=torch.float32)
142
- features=torch.cat([descriptors_tensor,fingerprint_tensor],dim=1).to(device)
 
 
 
 
 
143
  embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)
144
-
145
-
 
 
 
146
  with torch.no_grad():
147
- preds = model(embedding,features)
148
-
149
- preds_np=preds.cpu().numpy()
150
  keys = list(scalers.keys())
151
  preds_rescaled = np.concatenate([
152
  scalers[keys[i]].inverse_transform(preds_np[:, [i]])
153
  for i in range(6)
154
  ], axis=1)
155
-
156
- results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
157
-
158
- st.success("Predicted Properties:")
159
- for key, val in results.items():
160
- st.markdown(f"**{key}**: {val}")
161
-
162
- save_to_db(smiles_input, results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  except Exception as e:
165
- st.error(f"Prediction failed: {e}")
 
 
 
 
 
 
 
 
166
 
167
-
 
 
5
  import joblib
6
  from transformers import AutoTokenizer, AutoModel
7
  from rdkit import Chem
8
+ from rdkit.Chem import Descriptors, AllChem, Draw
9
  from datetime import datetime
10
  from db import get_database
11
  import random
12
+ import pandas as pd
13
+ import time
14
+ import base64
15
+ from io import BytesIO
16
 
17
  # Set seeds
18
  random.seed(42)
 
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
# Page styling and configuration.
# The settings live in one mapping so they can be read at a glance and are
# handed to Streamlit in a single call.
_PAGE_SETTINGS = {
    "page_title": "Polymer Property Prediction",
    "page_icon": "🧪",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
33
+
34
# Custom CSS
# The stylesheet is kept in a named constant so the injection call stays short.
# It defines the classes referenced by the page markup (main-header,
# sub-header, property-card, info-box) plus loader/tooltip helpers, and
# recolours the Streamlit progress bar.
_CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: 700;
        color: #4CAF50;
        text-align: center;
        margin-bottom: 1rem;
        background: linear-gradient(90deg, #f8f9fa 0%, #e9ecef 100%);
        padding: 1.5rem 0;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .sub-header {
        font-size: 1.5rem;
        font-weight: 600;
        color: #2E7D32;
        margin-bottom: 0.5rem;
    }
    .property-card {
        background-color: #f1f8e9;
        border-radius: 10px;
        padding: 1rem;
        margin: 0.5rem 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
        transition: transform 0.3s ease;
    }
    .property-card:hover {
        transform: translateY(-5px);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .loader {
        border: 16px solid #f3f3f3;
        border-radius: 50%;
        border-top: 16px solid #3498db;
        width: 50px;
        height: 50px;
        animation: spin 2s linear infinite;
        margin: 20px auto;
    }
    .info-box {
        background-color: #e3f2fd;
        border-left: 5px solid #2196f3;
        padding: 1rem;
        margin: 1rem 0;
        border-radius: 5px;
    }
    .tooltip {
        position: relative;
        display: inline-block;
        border-bottom: 1px dotted black;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 120px;
        background-color: black;
        color: #fff;
        text-align: center;
        border-radius: 6px;
        padding: 5px 0;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        margin-left: -60px;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50 !important;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
116
+
117
  # Load ChemBERTa
@st.cache_resource
def load_chemberta():
    """Load the ChemBERTa tokenizer and encoder (cached by ``st.cache_resource``).

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the
        module-level ``device`` and switched to eval mode.
    """
    checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
    with st.spinner("Loading ChemBERTa model..."):
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        encoder = AutoModel.from_pretrained(checkpoint)
        encoder = encoder.to(device).eval()
    return tokenizer, encoder
124
 
125
  # Load scalers
126
@st.cache_resource
def load_scalers():
    """Load the per-property scalers from their joblib files.

    Returns:
        A dict mapping each property label to the object loaded from its
        ``.joblib`` file. Insertion order follows the table below; ``show()``
        pairs the i-th key with the i-th model output column, so the order
        must not be changed.
    """
    scaler_files = (
        ("Tensile Strength (MPa)", "scaler_Tensile_strength_Mpa_.joblib"),
        ("Ionization Energy (eV)", "scaler_Ionization_Energy_eV_.joblib"),
        ("Electron Affinity (eV)", "scaler_Electron_Affinity_eV_.joblib"),
        ("logP", "scaler_LogP.joblib"),
        ("Refractive Index", "scaler_Refractive_Index.joblib"),
        ("Molecular Weight (g/mol)", "scaler_Molecular_Weight_g_mol_.joblib"),
    )
    return {label: joblib.load(path) for label, path in scaler_files}
136
 
137
  # Transformer model
138
  class TransformerRegressor(nn.Module):
 
152
  nn.Linear(128, output_dim)
153
  )
154
 
155
def forward(self, x, feat):
    """Fuse the embedding with the projected features and regress the targets.

    ``feat`` is passed through ``self.feat_proj``; the result is stacked with
    ``x`` along a new dimension 1 (so both must have the same shape), run
    through ``self.transformer_encoder``, averaged over dimension 1 and fed to
    ``self.regression_head``.

    Args:
        x: Embedding tensor.
        feat: Feature tensor accepted by ``self.feat_proj``.

    Returns:
        The output of ``self.regression_head``.
    """
    projected = self.feat_proj(feat)
    tokens = torch.stack((x, projected), dim=1)
    pooled = self.transformer_encoder(tokens).mean(dim=1)
    return self.regression_head(pooled)
161
 
162
  # Load model
@st.cache_resource
def load_model():
    """Build the regressor and load its trained weights (cached by Streamlit).

    Returns:
        A ``TransformerRegressor`` in eval mode on the module-level ``device``.

    Raises:
        ValueError: If the checkpoint cannot be read or does not match the
            model. The original exception is attached as ``__cause__`` so the
            full traceback stays available.
    """
    with st.spinner("Loading prediction model..."):
        model = TransformerRegressor()
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load("transformer_model.bin", map_location=device)
            model.load_state_dict(state_dict)
            model.eval().to(device)
        except Exception as e:
            raise ValueError(f"Failed to load model: {e}") from e
    return model
174
 
175
  # RDKit descriptors
 
197
  if mol is None:
198
  raise ValueError("Invalid SMILES string.")
199
  fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
200
+ return np.array(fp, dtype=np.float32).reshape(1, -1)
201
 
202
  # ChemBERTa embedding
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    """Return the mean-pooled ChemBERTa hidden state for one SMILES string.

    Args:
        smiles: SMILES text to encode.
        tokenizer: Callable tokenizer returning a mapping of tensors.
        chemberta: Model called with the tokenizer output as keyword
            arguments; its result must expose ``last_hidden_state``.

    Returns:
        ``last_hidden_state`` averaged over dimension 1, computed on the
        module-level ``device``.
    """
    encoded = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    # Move every tokenizer tensor to the device the model runs on.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)
    with torch.no_grad():
        hidden = chemberta(**on_device).last_hidden_state
    return hidden.mean(dim=1)
209
 
210
  # Save to DB
211
def save_to_db(smiles, predictions, mol_image=None):
    """Store one prediction in the ``polymer_predictions`` collection.

    Args:
        smiles: SMILES string the prediction was made for.
        predictions: Mapping of property name to numeric value; every value is
            converted with ``float()`` before storage.
        mol_image: Optional image payload saved under ``molecule_image``;
            omitted from the document when falsy.

    Returns:
        The id of the inserted document, taken from the insert result.
    """
    doc = {
        "smiles": smiles,
        "predictions": {k: float(v) for k, v in predictions.items()},
        "timestamp": datetime.now(),
    }
    if mol_image:
        doc["molecule_image"] = mol_image

    # Use the id reported by the insert result instead of relying on the
    # driver having added "_id" to ``doc`` as a side effect.
    result = get_database()["polymer_predictions"].insert_one(doc)
    return result.inserted_id
224
 
225
+ # Get molecule image as base64
226
def get_molecule_image(smiles):
    """Render a SMILES string as a base64-encoded 300x300 PNG.

    Returns:
        The base64 text of the PNG, or ``None`` when
        ``Chem.MolFromSmiles`` yields a falsy result for ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None

    png_buffer = BytesIO()
    Draw.MolToImage(mol, size=(300, 300)).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
234
+
235
# Example SMILES for users to try, recorded as (polymer name, SMILES) pairs.
_EXAMPLE_POLYMERS = (
    ("Polyisobutylene", "CC(C)(C)CC(C)(C)C"),
    ("Polypropylene", "CCC(C)CC(C)CC"),
    ("Polyethylene", "CCCCCCCC"),
    ("Polystyrene", "CC(C)(c1ccccc1)C"),
    ("PMMA", "COC(=O)C(C)OC(=O)C"),
)
# Plain list of the SMILES strings, in the order given above.
EXAMPLE_SMILES = [smiles for _name, smiles in _EXAMPLE_POLYMERS]
243
+
244
+ # Get history from database
245
def get_prediction_history(limit=5):
    """Return stored predictions sorted by ``timestamp`` descending.

    Args:
        limit: Value passed to the cursor's ``limit()``.

    Returns:
        A list of documents from the ``polymer_predictions`` collection.
    """
    collection = get_database()["polymer_predictions"]
    newest_first = collection.find().sort("timestamp", -1)
    return list(newest_first.limit(limit))
249
 
250
+ # Sidebar
251
# Sidebar help text: (expander title, markdown body), in display order.
_PROPERTY_EXPLANATIONS = (
    ("Tensile Strength", """
        **Tensile Strength (MPa)** measures the maximum stress a material can withstand before breaking.
        Higher values indicate stronger materials.
        """),
    ("Ionization Energy", """
        **Ionization Energy (eV)** is the energy required to remove an electron from an atom or molecule.
        It affects chemical reactivity and stability.
        """),
    ("Electron Affinity", """
        **Electron Affinity (eV)** measures how much energy is released when an electron is added to a neutral atom.
        It influences a polymer's electrical properties.
        """),
    ("logP", """
        **logP** is the partition coefficient that measures how a substance distributes between water and lipid phases.
        It affects solubility and permeability of polymers.
        """),
    ("Refractive Index", """
        **Refractive Index** measures how light propagates through the material.
        It's important for optical applications of polymers.
        """),
    ("Molecular Weight", """
        **Molecular Weight (g/mol)** is the mass of a molecule.
        It affects mechanical properties, processability, and many other characteristics.
        """),
)

# Number of SMILES characters shown in a history entry's title.
_HISTORY_PREVIEW_CHARS = 15


def show_sidebar():
    """Render the sidebar: tool description, property glossary, recent history.

    The history section lists the five most recent stored predictions. A
    failure to read the history is reported as a sidebar warning instead of
    aborting the whole page, because the history is auxiliary to prediction.
    """
    st.sidebar.markdown("<div class='sub-header'>About This Tool</div>", unsafe_allow_html=True)
    st.sidebar.info("""
    This tool predicts key properties of polymers based on their SMILES representation.

    It uses a transformer neural network combined with ChemBERTa embeddings and molecular descriptors.
    """)

    st.sidebar.markdown("<div class='sub-header'>Property Explanations</div>", unsafe_allow_html=True)
    for title, explanation in _PROPERTY_EXPLANATIONS:
        with st.sidebar.expander(title):
            st.write(explanation)

    st.sidebar.markdown("<div class='sub-header'>Recent Predictions</div>", unsafe_allow_html=True)
    try:
        history = get_prediction_history(5)
    except Exception as exc:  # UI boundary: keep the page usable without history
        st.sidebar.warning(f"Could not load prediction history: {exc}")
        return

    if not history:
        st.sidebar.write("No prediction history available.")
        return

    for rank, item in enumerate(history, start=1):
        smiles = item["smiles"]
        # Only add the ellipsis when the SMILES was actually shortened.
        if len(smiles) > _HISTORY_PREVIEW_CHARS:
            preview = f"{smiles[:_HISTORY_PREVIEW_CHARS]}..."
        else:
            preview = smiles
        timestamp = item["timestamp"].strftime("%Y-%m-%d %H:%M")
        with st.sidebar.expander(f"#{rank}: {preview} ({timestamp})"):
            st.code(smiles, language="text")
            for prop, val in item["predictions"].items():
                st.write(f"**{prop}**: {val:.4f}")
309
 
310
+ # Show example SMILES
311
def show_examples():
    """Render one button per entry of ``EXAMPLE_SMILES``.

    Clicking a button stores that SMILES in ``st.session_state.smiles_input``
    and reruns the script so the text input picks it up.
    """
    st.markdown("<div class='sub-header'>Example SMILES</div>", unsafe_allow_html=True)
    if not EXAMPLE_SMILES:
        # st.columns() needs at least one column.
        return

    # Display names, in the same order as EXAMPLE_SMILES.
    polymer_names = ("Polyisobutylene", "Polypropylene", "Polyethylene", "Polystyrene", "PMMA")
    # st.experimental_rerun was removed from newer Streamlit releases; prefer
    # st.rerun and fall back only on old versions that lack it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun

    cols = st.columns(len(EXAMPLE_SMILES))
    for i, (col, smiles) in enumerate(zip(cols, EXAMPLE_SMILES)):
        # Fall back to the SMILES itself if an example has no display name.
        label = polymer_names[i] if i < len(polymer_names) else smiles
        with col:
            if st.button(label, key=f"example_{i}"):
                st.session_state.smiles_input = smiles
                rerun()
321
+
322
+ # Property visualization
323
def visualize_properties(results):
    """Show each predicted property as a progress bar plus a summary bar chart.

    Every value is scaled into ``[0, 1]`` against a fixed display range for
    its property (``(0, 1)`` for unknown names) and clamped to that interval.

    Args:
        results: Mapping of property name to numeric prediction. Values may be
            NumPy scalars; they are converted to built-in floats before being
            handed to ``st.progress``.
    """
    st.markdown("<div class='sub-header'>Property Visualization</div>", unsafe_allow_html=True)

    # Display ranges used only to scale the bars; out-of-range values are clamped.
    property_ranges = {
        "Tensile Strength (MPa)": (0, 200),
        "Ionization Energy (eV)": (5, 15),
        "Electron Affinity (eV)": (0, 5),
        "logP": (-5, 10),
        "Refractive Index": (1, 2),
        "Molecular Weight (g/mol)": (0, 5000),
    }

    normalized_values = {}
    for prop, value in results.items():
        min_val, max_val = property_ranges.get(prop, (0, 1))
        fraction = (float(value) - min_val) / (max_val - min_val)
        # Clamp with float bounds so the result is always a built-in float:
        # st.progress treats an int as a 0-100 percentage (int 1 would mean
        # 1 %, not 100 %) and rejects NumPy scalar types.
        normalized_values[prop] = min(max(fraction, 0.0), 1.0)

    # One card per property, laid out three to a row.
    cols = st.columns(3)
    for i, (prop, norm_val) in enumerate(normalized_values.items()):
        with cols[i % 3]:
            st.markdown("<div class='property-card'>", unsafe_allow_html=True)
            st.markdown(f"<h4>{prop}</h4>", unsafe_allow_html=True)
            st.progress(norm_val)
            st.markdown(f"<h3 style='text-align: center;'>{results[prop]:.4f}</h3>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

    # Bar chart comparing the normalized values side by side.
    normalized_df = pd.DataFrame({
        'Property': list(normalized_values),
        'Normalized Value': list(normalized_values.values()),
        'Actual Value': [results[prop] for prop in normalized_values],
    })

    st.bar_chart(normalized_df.set_index('Property')['Normalized Value'])
363
+
364
+ # Main function
365
def show():
    """Render the polymer property prediction page.

    Flow: header and sidebar, SMILES input with a Clear button and example
    buttons, a structure preview for valid input, then (on "Predict
    Properties") feature computation, model inference, inverse scaling,
    persistence of the result and its visualisation. Any exception raised
    during prediction is reported on the page rather than propagated.
    """
    # st.experimental_rerun was removed from newer Streamlit releases; prefer
    # st.rerun and fall back only on old versions that lack it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun

    # Initialize session state for SMILES input
    if 'smiles_input' not in st.session_state:
        st.session_state.smiles_input = ""

    # Main header
    st.markdown("<div class='main-header'>🧪 Polymer Property Prediction</div>", unsafe_allow_html=True)

    # Sidebar
    show_sidebar()

    # Input section
    st.markdown("<div class='sub-header'>Input Your Polymer</div>", unsafe_allow_html=True)

    # SMILES input next to a Clear button
    col1, col2 = st.columns([3, 1])
    with col1:
        smiles_input = st.text_input(
            "Enter SMILES Representation",
            value=st.session_state.smiles_input,
            help="SMILES (Simplified Molecular Input Line Entry System) is a notation representing molecular structure.",
        )
    with col2:
        st.markdown("<br>", unsafe_allow_html=True)
        if st.button("Clear", key="clear_button"):
            st.session_state.smiles_input = ""
            rerun()

    # Example SMILES section
    show_examples()

    # Input validation: nothing is shown for empty input, a preview for a
    # parseable SMILES and a warning otherwise.
    is_valid = False
    if smiles_input:
        mol = Chem.MolFromSmiles(smiles_input)
        is_valid = mol is not None

        if is_valid:
            st.session_state.smiles_input = smiles_input
            col1, col2 = st.columns([1, 2])
            with col1:
                mol_img = get_molecule_image(smiles_input)
                if mol_img:
                    st.markdown(f"<img src='data:image/png;base64,{mol_img}' style='max-width:100%;'>", unsafe_allow_html=True)
            with col2:
                st.markdown("<div class='info-box'>", unsafe_allow_html=True)
                st.markdown("### Molecule Properties")
                st.write(f"**Formula:** {Chem.rdMolDescriptors.CalcMolFormula(mol)}")
                st.write(f"**Molecular Weight:** {Descriptors.MolWt(mol):.2f} g/mol")
                st.write(f"**Rings:** {Descriptors.RingCount(mol)}")
                st.write(f"**H-Bond Donors:** {Descriptors.NumHDonors(mol)}")
                st.write(f"**H-Bond Acceptors:** {Descriptors.NumHAcceptors(mol)}")
                st.markdown("</div>", unsafe_allow_html=True)
        else:
            st.warning("Invalid SMILES string. Please check your input.")

    # Prediction button (disabled until the input parses)
    run_prediction = st.button("🔍 Predict Properties", disabled=not is_valid, key="predict_button")

    if run_prediction:
        # Created before the try block so the status line can be cleared on failure.
        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            # Step 1: Load models
            status_text.text("Loading models...")
            model = load_model()
            tokenizer, chemberta = load_chemberta()
            scalers = load_scalers()
            progress_bar.progress(25)
            time.sleep(0.5)  # Simulate processing time for better UX

            # Step 2: Compute molecular features
            status_text.text("Computing molecular features...")
            descriptors = compute_descriptors(smiles_input)
            descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
            fingerprint = get_morgan_fingerprint(smiles_input)
            fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32)
            features = torch.cat([descriptors_tensor, fingerprint_tensor], dim=1).to(device)
            progress_bar.progress(50)
            time.sleep(0.5)  # Simulate processing time

            # Step 3: Generate embeddings
            status_text.text("Generating ChemBERTa embeddings...")
            embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)
            progress_bar.progress(75)
            time.sleep(0.5)  # Simulate processing time

            # Step 4: Make predictions
            status_text.text("Making predictions...")
            with torch.no_grad():
                preds = model(embedding, features)

            preds_np = preds.cpu().numpy()
            # Column i of the model output is un-scaled with the i-th scaler,
            # following the key order of the scalers dict.
            keys = list(scalers.keys())
            preds_rescaled = np.concatenate([
                scalers[key].inverse_transform(preds_np[:, [i]])
                for i, key in enumerate(keys)
            ], axis=1)

            # Plain floats keep NumPy scalar types out of the UI and DB layers.
            results = {key: float(val) for key, val in zip(keys, preds_rescaled.flatten())}
            progress_bar.progress(100)
            status_text.empty()

            # Save to database
            mol_img = get_molecule_image(smiles_input)
            save_to_db(smiles_input, results, mol_img)

            # Display results
            st.success("✅ Prediction completed successfully!")

            # Visualize results
            visualize_properties(results)

            # Detailed results in expandable section
            with st.expander("View Detailed Results"):
                result_df = pd.DataFrame({
                    'Property': list(results.keys()),
                    'Predicted Value': [f"{val:.4f}" for val in results.values()]
                })
                st.table(result_df)

                # Export options
                csv = result_df.to_csv(index=False)
                st.download_button(
                    label="Download Results as CSV",
                    data=csv,
                    file_name=f"polymer_prediction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                    mime="text/csv"
                )

        except Exception as e:
            # Do not leave a stale "in progress" message next to the error.
            status_text.empty()
            st.error(f"Prediction failed: {str(e)}")
            st.code(str(e))

    # Footer
    st.markdown("""
    <div style="text-align: center; margin-top: 3rem; padding-top: 1rem; border-top: 1px solid #ccc; color: #666;">
        <p>Polymer Property Prediction Tool - © 2025</p>
    </div>
    """, unsafe_allow_html=True)
505
 
506
# Render the page only when this file is executed as a script; importing the
# module does not call show().
if __name__ == "__main__":
    show()