import streamlit as st
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
# Model setup: ChemBERTa tokenizer + encoder used to embed SMILES strings.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta.eval()
# Model architecture: fuses the ChemBERTa embedding with RDKit features via a
# small Transformer encoder, then maps the pooled representation to six targets.
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        # Project the 2058-dim feature vector (10 descriptors + 2048 fingerprint
        # bits) into the same space as the ChemBERTa embedding.
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=1024,
            dropout=0.1, batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        # Treat the SMILES embedding and the projected features as a
        # two-token sequence, encode it, then mean-pool before regression.
        feat_emb = self.feat_proj(feat)
        stacked = torch.stack([x, feat_emb], dim=1)  # (batch, 2, emb_dim)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)
        return self.regression_head(aggregated)
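# Illustrative smoke test for the architecture above (not part of the app);
# the random tensors merely stand in for a real embedding and feature vector:
#
# m = TransformerRegressor()
# x = torch.randn(1, 768)    # fake ChemBERTa [CLS] embedding
# f = torch.randn(1, 2058)   # fake descriptor + fingerprint vector
# assert m(x, f).shape == (1, 6)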
# Load trained weights for CPU inference.
model = TransformerRegressor()
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
model.eval()
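# Note: Streamlit reruns this script on every interaction, so the models above
# are reloaded each time. A minimal caching sketch using Streamlit's
# st.cache_resource (the load_models name is illustrative, not from the app):
#
# @st.cache_resource
# def load_models():
#     tok = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
#     enc = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").eval()
#     reg = TransformerRegressor()
#     reg.load_state_dict(torch.load("transformer_model.pt", map_location="cpu"))
#     return tok, enc, reg.eval()
#
# tokenizer, chemberta, model = load_models()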
# Feature Functions | |
descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA, | |
Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors, | |
Descriptors.NumHDonors, Descriptors.RingCount, | |
Descriptors.FractionCSP3, Descriptors.HeavyAtomCount, | |
Descriptors.NHOHCount] | |
def fix_smiles(s):
    """Canonicalize a SMILES string; return None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        pass
    return None
def compute_features(smiles):
    """Return 10 descriptors + 2048 Morgan fingerprint bits (2058 values)."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048  # all-zero fallback for unparseable input
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)
def embed_smiles(smiles_list):
    """Embed SMILES with ChemBERTa; returns the [CLS] token vectors."""
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
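# Quick shape check (illustrative; "CCO" is just ethanol as a stand-in for a
# polymer repeat-unit SMILES). The widths must match the model's defaults:
#
# emb = embed_smiles(["CCO"])        # -> torch.Size([1, 768]), matches emb_dim
# feats = compute_features("CCO")    # -> len(feats) == 2058, matches feat_dim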
# Streamlit UI
st.set_page_config(page_title="TransPolymer", layout="centered")
st.title("TransPolymer - Predict Polymer Properties")
smiles_input = st.text_input("Enter a SMILES representation of the polymer")
if st.button("Predict"):
    fixed = fix_smiles(smiles_input)
    if not fixed:
        st.error("Invalid SMILES string.")
    else:
        # Build the two model inputs: RDKit features and the ChemBERTa embedding.
        features = compute_features(fixed)
        features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
        embedding = embed_smiles([fixed])
        with torch.no_grad():
            pred = model(embedding, features_tensor)
        result = pred.numpy().flatten()
        # Output order must match the training targets.
        properties = [
            "Tensile Strength",
            "Ionization Energy",
            "Electron Affinity",
            "logP",
            "Refractive Index",
            "Molecular Weight",
        ]
        st.success("Predicted Polymer Properties:")
        for prop, val in zip(properties, result):
            st.write(f"**{prop}**: {val:.4f}")
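# To run locally (assuming this file is saved as app.py and the weights file
# transformer_model.pt sits next to it):
#
#   streamlit run app.py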