Spaces:
Sleeping
Sleeping
File size: 17,727 Bytes
4b283df 84dad8f 5e9e549 b73e3e2 4b283df 84dad8f 5909561 4b283df 8e77cfc cf36af6 5909561 f8c8eb7 76beebb cf36af6 5834d19 b69e05d 77a4bbb 5909561 76beebb a08c2a4 5909561 a08c2a4 b69e05d 76beebb 5909561 5834d19 76beebb 5e9e549 a715cd6 5e9e549 a715cd6 3de6f45 76beebb 23348f0 3de6f45 eea9e94 3de6f45 5e9e549 3de6f45 eea9e94 3de6f45 5e9e549 5909561 23348f0 eea9e94 76beebb 4b283df 5909561 4b283df 77a4bbb 76beebb 4b283df 76beebb 4b283df 5834d19 76beebb 4b283df 76beebb 23348f0 af4f3b0 5909561 af4f3b0 76beebb a715cd6 76beebb 5909561 5834d19 a715cd6 5909561 4b283df 76beebb 5909561 5834d19 4b283df 5834d19 4b283df 5909561 4b283df 5909561 4b283df 5909561 8a98757 8ae7b7c 5909561 4b283df 5909561 4b283df 8a98757 8ae7b7c 5909561 3846070 5909561 f8e3d22 5909561 4b283df 5909561 a71233a 76beebb 5909561 3846070 5909561 23348f0 5909561 23348f0 5909561 3846070 5909561 5834d19 3846070 5909561 4b283df 5909561 4b283df 23348f0 5834d19 4b283df 5909561 3846070 5909561 4b283df 5909561 77a4bbb 5909561 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 |
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, Draw
from datetime import datetime
from db import get_database
import random
import pandas as pd
import time
import base64
from io import BytesIO
# Reproducibility: put the Python, NumPy and PyTorch RNGs on the same fixed seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn  # loop helper only; do not leave it in the module namespace
# cuDNN flags: benchmark mode off, deterministic mode on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Page styling and configuration
# The page icon had been stored as mojibake (the UTF-8 bytes of the test-tube
# emoji decoded with a Greek single-byte codec); it is restored here.
st.set_page_config(
    page_title="Polymer Property Prediction",
    page_icon="🧪",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS
# Injects one <style> element into the page. The rules cover:
#   .main-header / .sub-header  - title and section headings
#   .property-card              - card used around each predicted property
#   .info-box                   - blue call-out used for the molecule summary
#   .loader / @keyframes spin   - spinner (no use of it is visible in this file)
#   .tooltip                    - hover tooltip (no use of it is visible in this file)
#   .stProgress ...             - recolours Streamlit's progress bar green
# unsafe_allow_html=True is passed so that the raw <style> markup is accepted.
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
font-weight: 700;
color: #4CAF50;
text-align: center;
margin-bottom: 1rem;
background: linear-gradient(90deg, #f8f9fa 0%, #e9ecef 100%);
padding: 1.5rem 0;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.sub-header {
font-size: 1.5rem;
font-weight: 600;
color: #2E7D32;
margin-bottom: 0.5rem;
}
.property-card {
background-color: #f1f8e9;
border-radius: 10px;
padding: 1rem;
margin: 0.5rem 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
transition: transform 0.3s ease;
}
.property-card:hover {
transform: translateY(-5px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.loader {
border: 16px solid #f3f3f3;
border-radius: 50%;
border-top: 16px solid #3498db;
width: 50px;
height: 50px;
animation: spin 2s linear infinite;
margin: 20px auto;
}
.info-box {
background-color: #e3f2fd;
border-left: 5px solid #2196f3;
padding: 1rem;
margin: 1rem 0;
border-radius: 5px;
}
.tooltip {
position: relative;
display: inline-block;
border-bottom: 1px dotted black;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: black;
color: #fff;
text-align: center;
border-radius: 6px;
padding: 5px 0;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.stProgress > div > div > div > div {
background-color: #4CAF50 !important;
}
</style>
""", unsafe_allow_html=True)
# Load ChemBERTa
@st.cache_resource
def load_chemberta():
    """Load the ChemBERTa tokenizer and encoder used to embed SMILES strings.

    Both are fetched from the ``seyonec/ChemBERTa-zinc-base-v1`` checkpoint.
    The encoder is moved to the module-level ``device`` and switched to eval
    mode. A Streamlit spinner is shown while loading, and the function is
    wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
    with st.spinner("Loading ChemBERTa model..."):
        tok = AutoTokenizer.from_pretrained(checkpoint)
        encoder = AutoModel.from_pretrained(checkpoint)
        encoder = encoder.to(device).eval()
    return tok, encoder
# Load scalers
@st.cache_resource
def load_scalers():
    """Load one fitted scaler per predicted property from ``.joblib`` files.

    Returns:
        A dict mapping the display name of each property to the object
        ``joblib.load`` returns for its file. The insertion order is the
        order listed below; ``show`` indexes the model's output columns by
        that order, so it must not be changed.
    """
    scaler_files = (
        ("Tensile Strength (MPa)", "scaler_Tensile_strength_Mpa_.joblib"),
        ("Ionization Energy (eV)", "scaler_Ionization_Energy_eV_.joblib"),
        ("Electron Affinity (eV)", "scaler_Electron_Affinity_eV_.joblib"),
        ("logP", "scaler_LogP.joblib"),
        ("Refractive Index", "scaler_Refractive_Index.joblib"),
        ("Molecular Weight (g/mol)", "scaler_Molecular_Weight_g_mol_.joblib"),
    )
    return {label: joblib.load(path) for label, path in scaler_files}
# Transformer model
class TransformerRegressor(nn.Module):
    """Regressor that fuses an embedding with a projected feature vector.

    The feature vector is linearly projected to ``embedding_dim``, stacked
    with the embedding as a two-element sequence, passed through a
    Transformer encoder, mean-pooled over the two positions and fed to an
    MLP head with ``output_dim`` outputs.

    The default ``feat_dim`` of 2058 equals the 10 descriptors from
    ``compute_descriptors`` plus the 2048 bits from
    ``get_morgan_fingerprint``.

    Attribute names (``feat_proj``, ``transformer_encoder``,
    ``regression_head``) determine the ``state_dict`` keys, so they must stay
    as they are for saved weights to load.
    """

    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        # Brings the raw feature vector to the same width as the embedding.
        self.feat_proj = nn.Linear(feat_dim, embedding_dim)

        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        head_layers = [
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        ]
        self.regression_head = nn.Sequential(*head_layers)

    def forward(self, x, feat):
        """Predict ``output_dim`` values per sample.

        Args:
            x: Embedding tensor. It must have the same shape as the projected
                features, e.g. ``(batch, embedding_dim)``.
            feat: Feature tensor whose last dimension is ``feat_dim``.

        Returns:
            The regression head's output for the mean of the two encoded
            positions.
        """
        # Two "tokens" per sample: the embedding and the projected features.
        tokens = torch.stack((x, self.feat_proj(feat)), dim=1)
        pooled = self.transformer_encoder(tokens).mean(dim=1)
        return self.regression_head(pooled)
# Load model
@st.cache_resource
def load_model():
    """Build a ``TransformerRegressor`` and load its trained weights.

    The weights are read from ``transformer_model.bin`` in the working
    directory, mapped onto the module-level ``device``. The model is then put
    in eval mode and moved to that device. A Streamlit spinner is shown while
    this runs, and the function is wrapped in ``st.cache_resource``.

    Returns:
        The loaded model, ready for inference.

    Raises:
        ValueError: If reading the file, loading the state dict or moving the
            model fails. The underlying exception is attached as
            ``__cause__``.
    """
    with st.spinner("Loading prediction model..."):
        model = TransformerRegressor()
        try:
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # from a trusted source.
            state_dict = torch.load("transformer_model.bin", map_location=device)
            model.load_state_dict(state_dict)
            model.eval().to(device)
        except Exception as e:
            # Chain explicitly so the traceback shows the root cause rather
            # than "during handling of the above exception".
            raise ValueError(f"Failed to load model: {e}") from e
    return model
# RDKit descriptors
def compute_descriptors(smiles: str):
    """Compute ten RDKit descriptors for a SMILES string.

    Args:
        smiles: SMILES string to parse.

    Returns:
        A 1-D ``float32`` array of length 10, in the order listed in
        ``descriptor_fns`` below.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")

    # The order is part of the feature layout fed to the model; do not reorder.
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumRotatableBonds,
        Descriptors.NumHDonors,
        Descriptors.NumHAcceptors,
        Descriptors.FractionCSP3,
        Descriptors.HeavyAtomCount,
        Descriptors.RingCount,
        Descriptors.MolMR,
    )
    return np.array([fn(mol) for fn in descriptor_fns], dtype=np.float32)
# Morgan fingerprint
def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
    """Return the Morgan bit-vector fingerprint of a SMILES string.

    Args:
        smiles: SMILES string to parse.
        radius: Radius passed to RDKit's Morgan fingerprint routine.
        n_bits: Length of the bit vector.

    Returns:
        A ``float32`` array reshaped to a single row, ``(1, -1)``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")

    bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    bits = np.array(bit_vect, dtype=np.float32)
    return bits.reshape(1, -1)
# ChemBERTa embedding
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    """Embed a SMILES string by mean-pooling the encoder's last hidden state.

    Args:
        smiles: SMILES string to embed.
        tokenizer: Callable tokenizer, invoked with ``return_tensors="pt"``,
            padding and truncation enabled.
        chemberta: Encoder called with the tokenizer output; its result must
            expose ``last_hidden_state``.

    Returns:
        ``last_hidden_state`` averaged over dimension 1, computed without
        gradient tracking.
    """
    encoded = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    # Every tokenizer tensor has to live on the same device as the encoder.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = chemberta(**model_inputs).last_hidden_state
        pooled = hidden.mean(dim=1)
    return pooled
# Save to DB
def save_to_db(smiles, predictions, mol_image=None):
    """Insert one prediction record into the ``polymer_predictions`` collection.

    Args:
        smiles: SMILES string the prediction was made for.
        predictions: Mapping of property name to numeric value; each value is
            converted with ``float`` before storage.
        mol_image: Optional image payload. It is stored under
            ``"molecule_image"`` only when truthy.

    Returns:
        The record's ``"_id"`` entry, read back after ``insert_one``. This
        relies on ``insert_one`` adding that key to the dict it is given
        (presumably the MongoDB driver's behaviour; confirm against ``db``).
    """
    record = {
        "smiles": smiles,
        "predictions": {name: float(value) for name, value in predictions.items()},
        "timestamp": datetime.now(),
    }
    if mol_image:
        record["molecule_image"] = mol_image

    collection = get_database()["polymer_predictions"]
    collection.insert_one(record)
    return record["_id"]
# Get molecule image as base64
def get_molecule_image(smiles):
    """Render a SMILES string to a 300x300 PNG and return it base64-encoded.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The PNG bytes as a base64 text string, or ``None`` when the parsed
        molecule is falsy (for example, when parsing fails).
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None

    png_buffer = BytesIO()
    Draw.MolToImage(mol, size=(300, 300)).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
# Removed example SMILES
# Get history from database
def get_prediction_history(limit=5):
    """Fetch stored prediction records from ``polymer_predictions``.

    Args:
        limit: Maximum number of records to return.

    Returns:
        A list of records sorted on ``"timestamp"`` with direction ``-1``
        (descending), truncated to ``limit`` entries.
    """
    collection = get_database()["polymer_predictions"]
    cursor = collection.find().sort("timestamp", -1).limit(limit)
    return list(cursor)
# Sidebar
# (expander title, markdown body) pairs shown under "Property Explanations".
_PROPERTY_EXPLANATIONS = (
    ("Tensile Strength", """
**Tensile Strength (MPa)** measures the maximum stress a material can withstand before breaking.
Higher values indicate stronger materials.
"""),
    ("Ionization Energy", """
**Ionization Energy (eV)** is the energy required to remove an electron from an atom or molecule.
It affects chemical reactivity and stability.
"""),
    ("Electron Affinity", """
**Electron Affinity (eV)** measures how much energy is released when an electron is added to a neutral atom.
It influences a polymer's electrical properties.
"""),
    ("logP", """
**logP** is the partition coefficient that measures how a substance distributes between water and lipid phases.
It affects solubility and permeability of polymers.
"""),
    ("Refractive Index", """
**Refractive Index** measures how light propagates through the material.
It's important for optical applications of polymers.
"""),
    ("Molecular Weight", """
**Molecular Weight (g/mol)** is the mass of a molecule.
It affects mechanical properties, processability, and many other characteristics.
"""),
)

# Longest SMILES prefix shown in a history expander title.
_HISTORY_PREVIEW_CHARS = 15


def show_sidebar():
    """Render the sidebar: tool description, property glossary, recent history.

    The history section lists the five records returned by
    ``get_prediction_history(5)``. Each record is expected to carry
    ``"smiles"``, ``"timestamp"`` (an object with ``strftime``) and
    ``"predictions"`` (a mapping of property name to number).
    """
    st.sidebar.markdown("<div class='sub-header'>About This Tool</div>", unsafe_allow_html=True)
    st.sidebar.info("""
This tool predicts key properties of polymers based on their SMILES representation.
It uses a transformer neural network combined with ChemBERTa embeddings and molecular descriptors.
""")

    st.sidebar.markdown("<div class='sub-header'>Property Explanations</div>", unsafe_allow_html=True)
    for title, explanation in _PROPERTY_EXPLANATIONS:
        with st.sidebar.expander(title):
            st.write(explanation)

    st.sidebar.markdown("<div class='sub-header'>Recent Predictions</div>", unsafe_allow_html=True)
    history = get_prediction_history(5)
    if not history:
        st.sidebar.write("No prediction history available.")
        return

    for rank, item in enumerate(history, start=1):
        smiles = item["smiles"]
        timestamp = item["timestamp"].strftime("%Y-%m-%d %H:%M")
        # Only add the ellipsis when the SMILES was actually cut; previously
        # short strings such as "CCO" were also shown as "CCO...".
        if len(smiles) > _HISTORY_PREVIEW_CHARS:
            preview = f"{smiles[:_HISTORY_PREVIEW_CHARS]}..."
        else:
            preview = smiles
        with st.sidebar.expander(f"#{rank}: {preview} ({timestamp})"):
            st.code(smiles, language="text")
            for prop, val in item["predictions"].items():
                st.write(f"**{prop}**: {val:.4f}")
# Example SMILES section removed
# Property visualization
def _normalize_to_unit_interval(value, lower, upper):
    """Map ``value`` from ``[lower, upper]`` onto ``[0, 1]``, clamping outliers."""
    fraction = (value - lower) / (upper - lower)
    return float(max(0, min(fraction, 1)))


def visualize_properties(results):
    """Draw one progress-bar card per property plus a normalised bar chart.

    Args:
        results: Mapping of property name to predicted numeric value.

    Each value is scaled against a hard-coded display range for its property
    (``(0, 1)`` when the property is not listed) and clamped to ``[0, 1]``, so
    the bars show position within that range, not the raw value. The raw
    value is printed under each bar with four decimals.
    """
    st.markdown("<div class='sub-header'>Property Visualization</div>", unsafe_allow_html=True)

    # Display ranges used only for scaling the bars.
    property_ranges = {
        "Tensile Strength (MPa)": (0, 200),
        "Ionization Energy (eV)": (5, 15),
        "Electron Affinity (eV)": (0, 5),
        "logP": (-5, 10),
        "Refractive Index": (1, 2),
        "Molecular Weight (g/mol)": (0, 5000),
    }
    normalized_values = {
        prop: _normalize_to_unit_interval(value, *property_ranges.get(prop, (0, 1)))
        for prop, value in results.items()
    }

    # Cards are laid out three per row, wrapping round the columns.
    cols = st.columns(3)
    for i, (prop, norm_val) in enumerate(normalized_values.items()):
        with cols[i % 3]:
            st.markdown("<div class='property-card'>", unsafe_allow_html=True)
            st.markdown(f"<h4>{prop}</h4>", unsafe_allow_html=True)
            st.progress(norm_val)
            st.markdown(f"<h3 style='text-align: center;'>{results[prop]:.4f}</h3>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

    # Bar chart comparing all properties on the shared 0-1 scale.
    normalized_df = pd.DataFrame({
        'Property': list(normalized_values),
        'Normalized Value': list(normalized_values.values()),
        'Actual Value': [results[prop] for prop in normalized_values],
    })
    st.bar_chart(normalized_df.set_index('Property')['Normalized Value'])
# Main function
def _clear_smiles_input():
    """Button callback: empty the SMILES text box before the next rerun."""
    st.session_state.smiles_input = ""


def _render_molecule_preview(mol, mol_img):
    """Show the structure drawing next to a few basic facts about ``mol``."""
    image_col, facts_col = st.columns([1, 2])
    with image_col:
        if mol_img:
            st.markdown(f"<img src='data:image/png;base64,{mol_img}' style='max-width:100%;'>", unsafe_allow_html=True)
    with facts_col:
        st.markdown("<div class='info-box'>", unsafe_allow_html=True)
        st.markdown("### Molecule Properties")
        st.write(f"**Formula:** {Chem.rdMolDescriptors.CalcMolFormula(mol)}")
        st.write(f"**Rings:** {Descriptors.RingCount(mol)}")
        st.write(f"**H-Bond Donors:** {Descriptors.NumHDonors(mol)}")
        st.write(f"**H-Bond Acceptors:** {Descriptors.NumHAcceptors(mol)}")
        st.markdown("</div>", unsafe_allow_html=True)


def _predict_properties(smiles, progress_bar, status_text):
    """Run the prediction pipeline and return ``{property name: value}``.

    Progress is reported through ``progress_bar`` and ``status_text``. The
    half-second sleeps are deliberate pauses kept from the original code so
    that each stage stays visible to the user.
    """
    # Step 1: Load models
    status_text.text("Loading models...")
    model = load_model()
    tokenizer, chemberta = load_chemberta()
    scalers = load_scalers()
    progress_bar.progress(0.25)
    time.sleep(0.5)

    # Step 2: Compute molecular features (descriptors followed by fingerprint bits)
    status_text.text("Computing molecular features...")
    descriptors_tensor = torch.tensor(compute_descriptors(smiles), dtype=torch.float32).unsqueeze(0)
    fingerprint_tensor = torch.tensor(get_morgan_fingerprint(smiles), dtype=torch.float32)
    features = torch.cat([descriptors_tensor, fingerprint_tensor], dim=1).to(device)
    progress_bar.progress(0.50)
    time.sleep(0.5)

    # Step 3: Generate embeddings
    status_text.text("Generating ChemBERTa embeddings...")
    embedding = get_chemberta_embedding(smiles, tokenizer, chemberta)
    progress_bar.progress(0.75)
    time.sleep(0.5)

    # Step 4: Make predictions and undo the per-property scaling
    status_text.text("Making predictions...")
    with torch.no_grad():
        preds = model(embedding, features)
    preds_np = preds.cpu().numpy()
    # Output column i belongs to the i-th scaler key; iterate the keys rather
    # than a hard-coded count so the two cannot drift apart.
    keys = list(scalers)
    preds_rescaled = np.concatenate(
        [scalers[key].inverse_transform(preds_np[:, [i]]) for i, key in enumerate(keys)],
        axis=1,
    )
    results = dict(zip(keys, preds_rescaled.flatten()))
    progress_bar.progress(1.0)
    status_text.empty()
    return results


def _render_results(results):
    """Show the success banner, the charts, the result table and a CSV download."""
    st.success("✅ Prediction completed successfully!")
    visualize_properties(results)
    with st.expander("View Detailed Results"):
        result_df = pd.DataFrame({
            'Property': list(results.keys()),
            'Predicted Value': [f"{val:.4f}" for val in results.values()],
        })
        st.table(result_df)
        csv = result_df.to_csv(index=False)
        st.download_button(
            label="Download Results as CSV",
            data=csv,
            file_name=f"polymer_prediction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv",
        )


def show():
    """Render the whole page: header, sidebar, SMILES input, prediction, footer.

    The SMILES text box is bound to ``st.session_state.smiles_input`` through
    its widget key, and the "Clear" button resets it from an ``on_click``
    callback. The prediction button is enabled only when RDKit can parse the
    input. Any exception raised while predicting or rendering results is
    reported on the page instead of propagating.
    """
    # Seed the widget state before the text box is created.
    if 'smiles_input' not in st.session_state:
        st.session_state.smiles_input = ""

    # Main header
    st.markdown("<div class='main-header'>🧪 Polymer Property Prediction</div>", unsafe_allow_html=True)

    # Sidebar
    show_sidebar()

    # Input section
    st.markdown("<div class='sub-header'>Input Your Polymer</div>", unsafe_allow_html=True)
    input_col, clear_col = st.columns([3, 1])
    with input_col:
        # The widget owns session_state["smiles_input"] via its key. Previously
        # the value was copied into session state by hand after the widget had
        # rendered, which undid the "Clear" button in the same run.
        smiles_input = st.text_input(
            "Enter SMILES Representation",
            key="smiles_input",
            help="SMILES (Simplified Molecular Input Line Entry System) is a notation representing molecular structure.",
        ).strip()
    with clear_col:
        st.markdown("<br>", unsafe_allow_html=True)
        st.button("Clear", key="clear_button", on_click=_clear_smiles_input)

    # Input validation
    mol = Chem.MolFromSmiles(smiles_input) if smiles_input else None
    is_valid = mol is not None
    mol_img = None
    if is_valid:
        mol_img = get_molecule_image(smiles_input)
        _render_molecule_preview(mol, mol_img)
    elif smiles_input:
        st.warning("Invalid SMILES string. Please check your input.")

    # Prediction button
    # NOTE(review): the button's original icon was lost to an encoding error
    # and cannot be recovered from the file; the magnifier is a stand-in.
    run_prediction = st.button("🔍 Predict Properties", disabled=not is_valid, key="predict_button")
    if run_prediction:
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()
            results = _predict_properties(smiles_input, progress_bar, status_text)

            # Saving history is secondary: a database failure must not hide a
            # prediction that has already succeeded.
            try:
                save_to_db(smiles_input, results, mol_img)
            except Exception as db_error:
                st.warning(f"Prediction succeeded but could not be saved to history: {db_error}")

            _render_results(results)
        except Exception as e:
            st.error(f"Prediction failed: {str(e)}")
            st.code(str(e))

    # Footer
    st.markdown("""
<div style="text-align: center; margin-top: 3rem; padding-top: 1rem; border-top: 1px solid #ccc; color: #666;">
<p>Polymer Property Prediction Tool - © 2025</p>
</div>
""", unsafe_allow_html=True)
# Render the page when the file is executed as a script.
if __name__ == "__main__":
    show()