import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
import pandas as pd

# --- Model setup -----------------------------------------------------------
# ChemBERTa encoder used to embed SMILES strings; put in eval mode (frozen).
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta.eval()


class TransformerRegressor(nn.Module):
    """Transformer encoder over the pair (SMILES embedding, projected RDKit
    features) with an MLP regression head predicting `output_dim` properties.

    Args:
        emb_dim: ChemBERTa hidden size (768 for this checkpoint).
        feat_dim: RDKit feature length (10 descriptors + 2048 fp bits = 2058).
        output_dim: number of regressed properties.
        nhead: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        # Bug fix: the encoder layer previously hard-coded d_model=768 and
        # nhead=8, silently ignoring the constructor arguments. Using the
        # parameters (whose defaults are 768/8) keeps the architecture — and
        # thus checkpoint compatibility — identical while making the class
        # honestly configurable.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        """Predict properties.

        Args:
            x: (batch, emb_dim) SMILES embedding.
            feat: (batch, feat_dim) RDKit descriptor + fingerprint vector.

        Returns:
            (batch, output_dim) tensor of property predictions.
        """
        feat_emb = self.feat_proj(feat)
        # Treat the two representations as a length-2 token sequence.
        stacked = torch.stack([x, feat_emb], dim=1)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)  # mean-pool the two "tokens"
        return self.regression_head(aggregated)


# Load the trained checkpoint onto CPU and freeze for inference.
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device('cpu')))
model.eval()

# --- Feature functions -----------------------------------------------------
# 10 scalar RDKit descriptors; presumably this order matches what the model
# was trained with — do not reorder (TODO: confirm against training code).
descriptor_fns = [
    Descriptors.MolWt,
    Descriptors.MolLogP,
    Descriptors.TPSA,
    Descriptors.NumRotatableBonds,
    Descriptors.NumHAcceptors,
    Descriptors.NumHDonors,
    Descriptors.RingCount,
    Descriptors.FractionCSP3,
    Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]


def fix_smiles(s):
    """Canonicalize a SMILES string; return None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    # Bug fix: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt. RDKit/str failures are ordinary exceptions.
    except Exception:
        return None
    return None


def compute_features(smiles):
    """Return the 2058-dim feature vector: 10 descriptors + 2048 Morgan bits.

    Falls back to an all-zero vector when RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)


def embed_smiles(smiles_list):
    """Embed SMILES strings with ChemBERTa.

    Returns the first-token (CLS-position) hidden states, shape (batch, 768).
    """
    inputs = tokenizer(
        smiles_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    # Bug fix: the frozen encoder was run without no_grad, building an
    # unnecessary autograd graph (wasted memory) on every prediction.
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]


# --- Streamlit UI ----------------------------------------------------------
st.set_page_config(page_title="TransPolymer", layout="centered")
st.title("TransPolymer - Predict Polymer Properties")

smiles_input = st.text_input("Enter SMILES Representation of Polymer")

if st.button("Predict"):
    fixed = fix_smiles(smiles_input)
    if not fixed:
        st.error("Invalid SMILES string.")
    else:
        features = compute_features(fixed)
        features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
        embedding = embed_smiles([fixed])
        with torch.no_grad():
            pred = model(embedding, features_tensor)
        result = pred.numpy().flatten()
        # Output order must match the training target order — TODO confirm.
        properties = [
            "Tensile Strength",
            "Ionization Energy",
            "Electron Affinity",
            "logP",
            "Refractive Index",
            "Molecular Weight",
        ]
        st.success("Predicted Polymer Properties:")
        for prop, val in zip(properties, result):
            st.write(f"**{prop}**: {val:.4f}")