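"""Streamlit page for predicting polymer properties from a SMILES string.

Features are built from ChemBERTa embeddings, RDKit descriptors, and a Morgan
fingerprint, fed to a Transformer regressor, rescaled with per-target scalers,
and the results are stored in the database.
"""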
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from datetime import datetime
from db import get_database  # ensure this module is available
import random

# ------------------------ Ensuring Deterministic Behavior ------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use GPU acceleration when CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
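# Assumption: the deployed Streamlit version provides st.cache_resource (1.18+).
# Caching keeps one tokenizer/model instance per process instead of
# re-downloading and re-initializing them on every rerun.
@st.cache_resource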
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    model.to(device)  # send the model to the GPU if available
    return tokenizer, model

# ------------------------ Load Scalers ------------------------
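# The .joblib files below are per-target scalers fitted at training time;
# they must sit in the working directory alongside this script.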
scalers = {
    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
    "logP": joblib.load("scaler_LogP.joblib"),
    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib"),
}

# ------------------------ Transformer Model ------------------------
class TransformerRegressor(nn.Module):
    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.0,  # no dropout, for deterministic predictions
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.feat_proj(x)            # project features to the encoder width
        x = self.transformer_encoder(x)  # (batch, seq, embedding_dim)
        x = x.mean(dim=1)                # mean-pool over the sequence dimension
        return self.regression_head(x)

# ------------------------ Load Model ------------------------
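# Same caching assumption as load_chemberta: load the weights once per process
# rather than on every button click.
@st.cache_resource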
def load_model():
    # Initialize the architecture, then load the trained weights into it
    model = TransformerRegressor()
    try:
        state_dict = torch.load("transformer_model(1).bin", map_location=device)  # load onto the correct device
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)  # send the model to the GPU if available
    except Exception as e:
        raise ValueError(f"Failed to load model: {e}")
    return model

# ------------------------ Descriptors ------------------------
def compute_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    descriptors = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        Descriptors.RingCount(mol),
        Descriptors.MolMR(mol),
    ]
    return np.array(descriptors, dtype=np.float32)

# ------------------------ Fingerprints ------------------------
def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string.")
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return np.array(fp, dtype=np.float32).reshape(1, -1)
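# Note: GetMorganFingerprintAsBitVect is RDKit's legacy fingerprint API; newer
# releases prefer rdFingerprintGenerator.GetMorganGenerator, but the legacy
# call still works and keeps the bit layout the model was trained on.
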
# ------------------------ Embedding ------------------------
def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    inputs = tokenizer(smiles, return_tensors="pt").to(device)  # move inputs to the model's device
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state.mean(dim=1)  # mean-pool over tokens instead of using the CLS token

# ------------------------ Save to DB ------------------------
def save_to_db(smiles, predictions):
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now(),
    }
    db = get_database()
    db["polymer_predictions"].insert_one(doc)

# ------------------------ Streamlit App ------------------------
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        try:
            # Load the regression model
            model = load_model()

            mol = Chem.MolFromSmiles(smiles_input)
            if mol is None:
                st.error("Invalid SMILES string.")
                return

            # Load the ChemBERTa tokenizer and model
            tokenizer, chemberta = load_chemberta()

            # Compute descriptors, fingerprint, and embedding, all on the same device
            descriptors = compute_descriptors(smiles_input)
            descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0).to(device)
            fingerprint = get_morgan_fingerprint(smiles_input)
            fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32).to(device)
            embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)

            # Concatenate the feature blocks (768 + 10 + 1280 = 2058 features)
            # and add a sequence dimension of length 1 for the encoder
            combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
            combined = combined_input.unsqueeze(1)

            with torch.no_grad():
                preds = model(combined)
            preds_np = preds.cpu().numpy()  # move back to CPU before converting to NumPy

            # Rescale each output back to its original units
            keys = list(scalers.keys())
            preds_rescaled = np.concatenate([
                scalers[keys[i]].inverse_transform(preds_np[:, [i]])
                for i in range(len(keys))
            ], axis=1)

            results = {key: round(float(val), 4) for key, val in zip(keys, preds_rescaled.flatten())}

            st.success("Predicted Properties:")
            for key, val in results.items():
                st.markdown(f"**{key}**: {val}")

            # Save the results to the database
            save_to_db(smiles_input, results)
        except Exception as e:
            st.error(f"Prediction failed: {e}")