# Transpolymer2 / app.py
# Hugging Face Space file by "transpolymer"; last commit 588c57e ("Update app.py", verified); 3.47 kB.
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
import pandas as pd
# Model Setup
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so the pretrained encoder is loaded through st.cache_resource:
# the first run loads it, and later reruns reuse the same objects.
CHEMBERTA_MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"


@st.cache_resource(show_spinner=False)  # no spinner element, so no visible UI output
def _load_chemberta(model_name=CHEMBERTA_MODEL_NAME):
    """Load the ChemBERTa tokenizer and encoder, with the encoder in eval mode.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, encoder)`` as produced by ``AutoTokenizer`` / ``AutoModel``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    encoder.eval()  # inference only: disable dropout
    return tok, encoder


tokenizer, chemberta = _load_chemberta()
# Model architecture
class TransformerRegressor(nn.Module):
    """Regress polymer properties from an embedding plus a hand-crafted feature vector.

    The feature vector is linearly projected to ``emb_dim`` and stacked with the
    embedding into a two-token sequence. A Transformer encoder mixes the two
    tokens, and their mean is fed to an MLP regression head.

    Args:
        emb_dim: Width of the input embedding ``x`` and of the encoder (``d_model``).
        feat_dim: Length of the feature vector ``feat``.
        output_dim: Number of regressed properties.
        nhead: Attention heads per encoder layer; must divide ``emb_dim``.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2,
                 dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        # d_model/nhead come from the constructor arguments so that the encoder
        # width always matches feat_proj's output and the incoming embedding.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        """Predict properties for a batch.

        Args:
            x: Embedding tensor, shape ``(batch, emb_dim)``.
            feat: Feature tensor, shape ``(batch, feat_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        feat_emb = self.feat_proj(feat)                 # (batch, emb_dim)
        stacked = torch.stack([x, feat_emb], dim=1)     # (batch, 2, emb_dim)
        encoded = self.transformer_encoder(stacked)     # batch_first layout
        aggregated = encoded.mean(dim=1)                # average the two tokens
        return self.regression_head(aggregated)
# Load model
@st.cache_resource(show_spinner=False)  # reused across Streamlit reruns
def _load_regressor(weights_path="transformer_model.pt"):
    """Build a ``TransformerRegressor`` on CPU, load its weights and put it in eval mode.

    Args:
        weights_path: Path to the saved ``state_dict``.

    Returns:
        The ready-to-use regressor.
    """
    net = TransformerRegressor()
    # NOTE(review): torch.load unpickles the file; only load weights shipped
    # with the app, never user-supplied files.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()  # inference only: disable dropout
    return net


model = _load_regressor()
# Feature Functions
# RDKit descriptor callables, evaluated in this order for every molecule.
# These 10 values, followed by the 2048 Morgan fingerprint bits, make up
# the 2058-long feature vector the regressor receives.
descriptor_fns = [
    Descriptors.MolWt,
    Descriptors.MolLogP,
    Descriptors.TPSA,
    Descriptors.NumRotatableBonds,
    Descriptors.NumHAcceptors,
    Descriptors.NumHDonors,
    Descriptors.RingCount,
    Descriptors.FractionCSP3,
    Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]
def fix_smiles(s):
    """Canonicalize a SMILES string with RDKit.

    Args:
        s: Raw SMILES text; surrounding whitespace is ignored.

    Returns:
        The canonical SMILES from ``Chem.MolToSmiles``. Returns ``None`` if ``s``
        is blank or not a string, cannot be parsed, or yields a molecule with
        no atoms. This function does not raise.
    """
    try:
        text = s.strip()
        if not text:
            # Blank input is never a valid polymer; reject it before parsing.
            return None
        mol = Chem.MolFromSmiles(text)
        # MolFromSmiles reports a parse failure by returning None, not by raising.
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Chem.MolToSmiles(mol)
    except Exception:
        # Best effort: non-string input (e.g. None) or unexpected RDKit errors
        # are reported as "invalid". Catching Exception rather than using a bare
        # except lets KeyboardInterrupt/SystemExit propagate.
        return None
def compute_features(smiles, radius=2, n_bits=2048):
    """Build the hand-crafted feature vector for one molecule.

    The vector is the values of ``descriptor_fns`` (in order) followed by a
    Morgan fingerprint bit vector: ``len(descriptor_fns) + n_bits`` numbers.
    With the defaults that is 10 + 2048 = 2058, matching the default
    ``feat_dim`` of ``TransformerRegressor``.

    Args:
        smiles: SMILES string to featurize.
        radius: Morgan fingerprint radius.
        n_bits: Morgan fingerprint length in bits.

    Returns:
        A list of numbers. If RDKit cannot parse ``smiles``, an all-zero list
        of the same length is returned instead of raising.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Derive the length from the real feature layout so the fallback can
        # never drift out of sync with the success path.
        return [0] * (len(descriptor_fns) + n_bits)
    desc = [fn(mol) for fn in descriptor_fns]
    # NOTE(review): GetMorganFingerprintAsBitVect is deprecated in recent RDKit
    # releases. Before switching to rdFingerprintGenerator, confirm it yields
    # identical bits, since the trained weights presumably depend on them.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
    return desc + list(fp)
def embed_smiles(smiles_list, max_length=128):
    """Embed SMILES strings with ChemBERTa, one vector per string.

    Args:
        smiles_list: List of SMILES strings. Inputs are padded to a common
            length and truncated to ``max_length`` tokens.
        max_length: Maximum tokenized length.

    Returns:
        The first token's last-layer hidden state for each input, shape
        ``(len(smiles_list), hidden_size)``.
    """
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True,
                       truncation=True, max_length=max_length)
    # Pure inference: without no_grad the encoder forward pass would record an
    # autograd graph on every request, wasting memory and time.
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
# Streamlit UI
st.set_page_config(page_title="TransPolymer", layout="centered")
st.title("TransPolymer - Predict Polymer Properties")
smiles_input = st.text_input("Enter SMILES Representation of Polymer")
if st.button("Predict"):
    fixed = fix_smiles(smiles_input)
    if not fixed:
        st.error("Invalid SMILES string.")
    else:
        features = compute_features(fixed)
        features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # batch of 1
        # Both forward passes are inference only. The ChemBERTa embedding runs
        # inside no_grad as well, so no autograd graph is built anywhere.
        with torch.no_grad():
            embedding = embed_smiles([fixed])
            pred = model(embedding, features_tensor)
        result = pred.numpy().flatten()
        # NOTE(review): assumed to be the regressor's output order (6 targets);
        # confirm against the training label order.
        properties = [
            "Tensile Strength",
            "Ionization Energy",
            "Electron Affinity",
            "logP",
            "Refractive Index",
            "Molecular Weight"
        ]
        st.success("Predicted Polymer Properties:")
        for prop, val in zip(properties, result):
            st.write(f"**{prop}**: {val:.4f}")