# NOTE: the three lines below ("Spaces: Sleeping") were Hugging Face Spaces
# UI status text accidentally captured in the code paste — kept as a comment.
# Third-party dependencies
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
from datetime import datetime

# Local module: db.py must expose get_database() returning a MongoDB database handle
from db import get_database
# Load tokenizer and ChemBERTa model.
# Streamlit re-executes the whole script on every widget interaction; without
# caching, the HF model would be reloaded from disk/network on each click.
@st.cache_resource
def load_chemberta():
    """Return (tokenizer, model) for seyonec/ChemBERTa-zinc-base-v1.

    The model is put in eval mode since it is used for inference only.
    Cached across Streamlit reruns via st.cache_resource.
    """
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model

tokenizer, chemberta = load_chemberta()
# Model architecture: fuses the ChemBERTa SMILES embedding with a molecular
# feature vector via a small Transformer encoder, then regresses 6 properties.
class TransformerRegressor(nn.Module):
    """Predict polymer properties from a SMILES embedding plus features.

    Inputs to forward():
      x    -- (batch, emb_dim) ChemBERTa [CLS] embedding
      feat -- (batch, feat_dim) feature vector
              (10 RDKit descriptors + 2048-bit Morgan fingerprint = 2058)
    Output: (batch, output_dim) regression predictions.
    """

    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        # Bug fix: the encoder layer previously hard-coded d_model=768 and
        # nhead=8, silently ignoring the emb_dim / nhead constructor args.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        """Run the fused regression. See class docstring for shapes."""
        feat_emb = self.feat_proj(feat)               # (batch, emb_dim)
        stacked = torch.stack([x, feat_emb], dim=1)   # (batch, 2, emb_dim): 2-token sequence
        encoded = self.transformer_encoder(stacked)   # (batch, 2, emb_dim)
        aggregated = encoded.mean(dim=1)              # mean-pool the two tokens
        return self.regression_head(aggregated)
# Load the trained regression model weights for CPU inference.
def load_regression_model():
    """Instantiate TransformerRegressor and load weights from transformer_model.pt.

    Returns the model in eval mode. Raises FileNotFoundError if the
    checkpoint is missing.
    """
    model = TransformerRegressor()
    # weights_only=True restricts deserialization to tensors, preventing
    # arbitrary code execution from a tampered checkpoint; safe here because
    # the file is a plain state_dict.
    state_dict = torch.load(
        "transformer_model.pt", map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_regression_model()
# Feature functions: the 10 scalar RDKit descriptors computed per molecule.
# Order matters — it must match the feature layout the model was trained on.
descriptor_fns = [
    Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
    Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
    Descriptors.NumHDonors, Descriptors.RingCount,
    Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]
def fix_smiles(s):
    """Canonicalize a SMILES string; return None if it cannot be parsed.

    Whitespace is stripped before parsing. A non-string input (e.g. None)
    is treated as invalid rather than raising.
    """
    try:
        mol = Chem.MolFromSmiles(s.strip())
    except Exception:
        # Was a bare `except:` — that also swallowed KeyboardInterrupt/SystemExit.
        return None
    return Chem.MolToSmiles(mol) if mol else None
def compute_features(smiles):
    """Return the 2058-dim feature list for a SMILES string.

    Layout: 10 RDKit descriptors (see descriptor_fns) followed by a
    2048-bit Morgan fingerprint (radius 2). An unparseable SMILES yields
    an all-zero vector so downstream tensor code never sees a ragged input.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)
def embed_smiles(smiles_list):
    """Return ChemBERTa [CLS]-token embeddings, shape (len(smiles_list), 768)."""
    inputs = tokenizer(
        smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128
    )
    # Inference only: disable autograd so no graph is built (saves memory/time).
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
# Function to save a prediction record to MongoDB.
def save_to_db(smiles, predictions):
    """Insert one prediction document into the polymer_predictions collection.

    predictions values may be numpy scalars from the model output; they are
    cast to native Python float so BSON serialization succeeds.
    """
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        # NOTE(review): naive local time — consider datetime.now(timezone.utc)
        # if records are compared across deployments.
        "timestamp": datetime.now(),
    }
    db = get_database()  # Connect to MongoDB
    collection = db["polymer_predictions"]
    collection.insert_one(doc)
# Prediction Page UI
def show():
    """Render the prediction page: SMILES input -> model -> display + persist."""
    # Original header contained mojibake "π¬" — restored to the 🔬 emoji it
    # most plausibly was (UTF-8 bytes misdecoded).
    st.markdown(
        "<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>",
        unsafe_allow_html=True,
    )
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if not st.button("Predict"):
        return

    fixed = fix_smiles(smiles_input)
    if not fixed:
        st.error("Invalid SMILES string.")
        return

    # Build the two model inputs: descriptor/fingerprint features and the
    # ChemBERTa embedding, both batched with size 1.
    features = compute_features(fixed)
    features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
    embedding = embed_smiles([fixed])

    with torch.no_grad():
        pred = model(embedding, features_tensor)
    result = pred.numpy().flatten()

    # Output order must match the training-target order of the model head.
    properties = [
        "Tensile Strength",
        "Ionization Energy",
        "Electron Affinity",
        "logP",
        "Refractive Index",
        "Molecular Weight",
    ]

    st.success("Predicted Polymer Properties:")
    predictions = {}
    for prop, val in zip(properties, result):
        st.write(f"**{prop}**: {val:.4f}")
        predictions[prop] = val

    # Save the prediction to MongoDB.
    # NOTE(review): persists the raw user input, not the canonical `fixed`
    # SMILES actually predicted on — confirm which one should be stored.
    save_to_db(smiles_input, predictions)
    st.success("Prediction saved successfully!")