# Transpolymer2 / prediction.py
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
from datetime import datetime
from db import get_database # Assuming you have a file db.py with get_database function to connect to MongoDB
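
# A minimal sketch of what db.py's get_database() is assumed to look like
# (MONGO_URI and the database name below are illustrative, not taken from this repo):
#
#   import os
#   from pymongo import MongoClient
#
#   def get_database():
#       client = MongoClient(os.environ.get("MONGO_URI", "mongodb://localhost:27017"))
#       return client["transpolymer"]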

# Load tokenizer and ChemBERTa model
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model

tokenizer, chemberta = load_chemberta()

# Define the model architecture: the ChemBERTa embedding and handcrafted features
# are fused as a two-token sequence by a small transformer encoder
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        # Project the 2058-d descriptor/fingerprint vector into the embedding space
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=1024, dropout=0.1, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)
        stacked = torch.stack([x, feat_emb], dim=1)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)
        return self.regression_head(aggregated)
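
# Shape walk-through (derived from the defaults above):
#   x    : (batch, 768)  ChemBERTa [CLS] embedding
#   feat : (batch, 2058) descriptors + fingerprint -> feat_proj -> (batch, 768)
#   stack -> (batch, 2, 768) -> encoder -> mean over tokens -> (batch, 768)
#   regression_head -> (batch, 6) predicted property values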

# Load the saved model weights
@st.cache_resource
def load_regression_model():
    model = TransformerRegressor()
    state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_regression_model()

# Feature functions: 10 RDKit descriptors + 2048-bit Morgan fingerprint = 2058 inputs (matches feat_dim)
descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
                  Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
                  Descriptors.NumHDonors, Descriptors.RingCount,
                  Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
                  Descriptors.NHOHCount]

def fix_smiles(s):
    """Canonicalize a SMILES string; return None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol is not None:
            return Chem.MolToSmiles(mol)
    except Exception:
        return None
    return None

def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fall back to an all-zero feature vector if parsing fails
        return [0] * 10 + [0] * 2048
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)

def embed_smiles(smiles_list):
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = chemberta(**inputs)
    # Use the [CLS] token embedding as the sequence-level representation
    return outputs.last_hidden_state[:, 0, :]
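
# Example (illustrative): embed_smiles(["CCO"]) yields a (1, 768) tensor,
# matching the emb_dim expected by TransformerRegressor.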

# Save a prediction to MongoDB
def save_to_db(smiles, predictions):
    # Convert prediction values to native Python floats so they are BSON-serializable
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now()
    }
    db = get_database()  # Connect to MongoDB
    collection = db["polymer_predictions"]
    collection.insert_one(doc)

# Prediction page UI
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        fixed = fix_smiles(smiles_input)
        if not fixed:
            st.error("Invalid SMILES string.")
        else:
            # Build the two model inputs: handcrafted features and the ChemBERTa embedding
            features = compute_features(fixed)
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
            embedding = embed_smiles([fixed])

            with torch.no_grad():
                pred = model(embedding, features_tensor)
            result = pred.numpy().flatten()

            properties = [
                "Tensile Strength",
                "Ionization Energy",
                "Electron Affinity",
                "logP",
                "Refractive Index",
                "Molecular Weight"
            ]

            predictions = {}
            st.success("Predicted Polymer Properties:")
            for prop, val in zip(properties, result):
                st.write(f"**{prop}**: {val:.4f}")
                predictions[prop] = val

            # Save the prediction to MongoDB
            save_to_db(smiles_input, predictions)
            st.success("Prediction saved successfully!")