Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,99 +1,30 @@
|
|
1 |
import streamlit as st
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from transformers import AutoTokenizer, AutoModel
|
5 |
-
from rdkit import Chem
|
6 |
-
from rdkit.Chem import AllChem, Descriptors
|
7 |
-
from torch import nn
|
8 |
-
import pandas as pd
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
feat_emb = self.feat_proj(feat)
|
30 |
-
stacked = torch.stack([x, feat_emb], dim=1)
|
31 |
-
encoded = self.transformer_encoder(stacked)
|
32 |
-
aggregated = encoded.mean(dim=1)
|
33 |
-
return self.regression_head(aggregated)
|
34 |
-
|
35 |
-
# Load model
|
36 |
-
model = TransformerRegressor()
|
37 |
-
model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device('cpu')))
|
38 |
-
model.eval()
|
39 |
-
|
40 |
-
# Feature Functions
|
41 |
-
descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
|
42 |
-
Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
|
43 |
-
Descriptors.NumHDonors, Descriptors.RingCount,
|
44 |
-
Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
|
45 |
-
Descriptors.NHOHCount]
|
46 |
-
|
47 |
-
def fix_smiles(s):
|
48 |
-
try:
|
49 |
-
mol = Chem.MolFromSmiles(s.strip())
|
50 |
-
if mol:
|
51 |
-
return Chem.MolToSmiles(mol)
|
52 |
-
except:
|
53 |
-
return None
|
54 |
-
return None
|
55 |
-
|
56 |
-
def compute_features(smiles):
|
57 |
-
mol = Chem.MolFromSmiles(smiles)
|
58 |
-
if not mol:
|
59 |
-
return [0]*10 + [0]*2048
|
60 |
-
desc = [fn(mol) for fn in descriptor_fns]
|
61 |
-
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
|
62 |
-
return desc + list(fp)
|
63 |
-
|
64 |
-
def embed_smiles(smiles_list):
|
65 |
-
inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
66 |
-
outputs = chemberta(**inputs)
|
67 |
-
return outputs.last_hidden_state[:, 0, :]
|
68 |
-
|
69 |
-
# Streamlit UI
|
70 |
-
st.set_page_config(page_title="TransPolymer", layout="centered")
|
71 |
-
st.title("TransPolymer - Predict Polymer Properties")
|
72 |
-
|
73 |
-
smiles_input = st.text_input("Enter SMILES Representation of Polymer")
|
74 |
-
|
75 |
-
if st.button("Predict"):
|
76 |
-
fixed = fix_smiles(smiles_input)
|
77 |
-
if not fixed:
|
78 |
-
st.error("Invalid SMILES string.")
|
79 |
else:
|
80 |
-
|
81 |
-
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
|
82 |
-
embedding = embed_smiles([fixed])
|
83 |
-
|
84 |
-
with torch.no_grad():
|
85 |
-
pred = model(embedding, features_tensor)
|
86 |
-
result = pred.numpy().flatten()
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
"Electron Affinity",
|
92 |
-
"logP",
|
93 |
-
"Refractive Index",
|
94 |
-
"Molecular Weight"
|
95 |
-
]
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
st.write(f"**{prop}**: {val:.4f}")
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
# Import your page files
|
4 |
+
import home
|
5 |
+
import prediction
|
6 |
+
import about
|
7 |
+
import contact
|
8 |
+
|
9 |
+
# Set up the page configuration
|
10 |
+
st.set_page_config(page_title="TransPolymer", layout="wide")
|
11 |
+
|
12 |
+
# Set up navigation logic
|
13 |
+
def load_page(page):
|
14 |
+
if page == "Home":
|
15 |
+
home.show()
|
16 |
+
elif page == "Predictions":
|
17 |
+
prediction.show()
|
18 |
+
elif page == "About":
|
19 |
+
about.show()
|
20 |
+
elif page == "Contact":
|
21 |
+
contact.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
else:
|
23 |
+
st.error("Page not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
# Navigation menu
|
26 |
+
st.sidebar.title("Navigation")
|
27 |
+
page = st.sidebar.radio("Select a Page", ["Home", "Predictions", "About", "Contact"])
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
# Call the function to display the selected page
|
30 |
+
load_page(page)
|
|