import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
from datetime import datetime

from db import get_database  # db.py is assumed to expose get_database(), returning a MongoDB connection


# Load the ChemBERTa tokenizer and encoder once per session
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model


tokenizer, chemberta = load_chemberta()


# Model architecture
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,  # use the constructor arguments instead of hard-coded 768/8
            nhead=nhead,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        # Project the descriptor/fingerprint vector into the embedding space,
        # then treat the ChemBERTa embedding and the projected features as a
        # two-token sequence for the transformer encoder.
        feat_emb = self.feat_proj(feat)
        stacked = torch.stack([x, feat_emb], dim=1)  # (batch, 2, emb_dim)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)  # mean-pool over the two tokens
        return self.regression_head(aggregated)


# Load the saved regression model (CPU-only inference)
@st.cache_resource
def load_regression_model():
    model = TransformerRegressor()
    state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_regression_model()


# Feature functions: 10 RDKit descriptors + 2048-bit Morgan fingerprint = 2058 features,
# matching feat_dim in TransformerRegressor
descriptor_fns = [
    Descriptors.MolWt,
    Descriptors.MolLogP,
    Descriptors.TPSA,
    Descriptors.NumRotatableBonds,
    Descriptors.NumHAcceptors,
    Descriptors.NumHDonors,
    Descriptors.RingCount,
    Descriptors.FractionCSP3,
    Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]


def fix_smiles(s):
    """Canonicalize a SMILES string; return None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        return None
    return None


def compute_features(smiles):
    """Compute the descriptor + fingerprint feature vector for one SMILES string."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        # Fall back to an all-zero vector for unparseable SMILES
        return [0] * 10 + [0] * 2048
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)


def embed_smiles(smiles_list):
    """Embed a batch of SMILES with ChemBERTa, using the [CLS] token embedding."""
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only; skip gradient tracking
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]


# Save a prediction record to MongoDB
def save_to_db(smiles, predictions):
    # Convert numpy/tensor values to native Python floats so they serialize to BSON
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now(),
    }
    db = get_database()
    collection = db["polymer_predictions"]
    collection.insert_one(doc)


# Prediction page UI
def show():
    st.markdown("