transpolymer commited on
Commit
1002edf
·
verified ·
1 Parent(s): 3dedab9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem, Descriptors
7
+ from torch import nn
8
+ import pandas as pd
9
+
10
+ # Model Setup
11
+ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
12
+ chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
13
+ chemberta.eval()
14
+
15
+ # Define your model architecture
16
+ class TransformerRegressor(nn.Module):
17
+ def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
18
+ super().__init__()
19
+ self.feat_proj = nn.Linear(feat_dim, emb_dim)
20
+ encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=1024, dropout=0.1, batch_first=True)
21
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
22
+ self.regression_head = nn.Sequential(
23
+ nn.Linear(emb_dim, 256), nn.ReLU(),
24
+ nn.Linear(256, 128), nn.ReLU(),
25
+ nn.Linear(128, output_dim)
26
+ )
27
+
28
+ def forward(self, x, feat):
29
+ feat_emb = self.feat_proj(feat)
30
+ stacked = torch.stack([x, feat_emb], dim=1)
31
+ encoded = self.transformer_encoder(stacked)
32
+ aggregated = encoded.mean(dim=1)
33
+ return self.regression_head(aggregated)
34
+
35
+ # Load model
36
+ model = TransformerRegressor()
37
+ model.load_state_dict(torch.load("best_model.pt", map_location=torch.device('cpu')))
38
+ model.eval()
39
+
40
+ # Feature Functions
41
+ descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
42
+ Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
43
+ Descriptors.NumHDonors, Descriptors.RingCount,
44
+ Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
45
+ Descriptors.NHOHCount]
46
+
47
+ def fix_smiles(s):
48
+ try:
49
+ mol = Chem.MolFromSmiles(s.strip())
50
+ if mol:
51
+ return Chem.MolToSmiles(mol)
52
+ except:
53
+ return None
54
+ return None
55
+
56
+ def compute_features(smiles):
57
+ mol = Chem.MolFromSmiles(smiles)
58
+ if not mol:
59
+ return [0]*10 + [0]*2048
60
+ desc = [fn(mol) for fn in descriptor_fns]
61
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
62
+ return desc + list(fp)
63
+
64
+ def embed_smiles(smiles_list):
65
+ inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
66
+ outputs = chemberta(**inputs)
67
+ return outputs.last_hidden_state[:, 0, :]
68
+
69
+ # Streamlit UI
70
+ st.set_page_config(page_title="TransPolymer", layout="centered")
71
+ st.title("TransPolymer - Predict Polymer Properties")
72
+
73
+ smiles_input = st.text_input("Enter SMILES Representation of Polymer")
74
+
75
+ if st.button("Predict"):
76
+ fixed = fix_smiles(smiles_input)
77
+ if not fixed:
78
+ st.error("Invalid SMILES string.")
79
+ else:
80
+ features = compute_features(fixed)
81
+ features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
82
+ embedding = embed_smiles([fixed])
83
+
84
+ with torch.no_grad():
85
+ pred = model(embedding, features_tensor)
86
+ result = pred.numpy().flatten()
87
+
88
+ properties = [
89
+ "Tensile Strength",
90
+ "Ionization Energy",
91
+ "Electron Affinity",
92
+ "logP",
93
+ "Refractive Index",
94
+ "Molecular Weight"
95
+ ]
96
+
97
+ st.success("Predicted Polymer Properties:")
98
+ for prop, val in zip(properties, result):
99
+ st.write(f"**{prop}**: {val:.4f}")