transpolymer commited on
Commit
ff928a7
·
verified ·
1 Parent(s): bba110f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -94
app.py CHANGED
@@ -1,99 +1,30 @@
1
  import streamlit as st
2
- import torch
3
- import numpy as np
4
- from transformers import AutoTokenizer, AutoModel
5
- from rdkit import Chem
6
- from rdkit.Chem import AllChem, Descriptors
7
- from torch import nn
8
- import pandas as pd
9
 
10
- # Model Setup
11
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
12
- chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
13
- chemberta.eval()
14
-
15
- # Define your model architecture
16
- class TransformerRegressor(nn.Module):
17
- def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
18
- super().__init__()
19
- self.feat_proj = nn.Linear(feat_dim, emb_dim)
20
- encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=1024, dropout=0.1, batch_first=True)
21
- self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
22
- self.regression_head = nn.Sequential(
23
- nn.Linear(emb_dim, 256), nn.ReLU(),
24
- nn.Linear(256, 128), nn.ReLU(),
25
- nn.Linear(128, output_dim)
26
- )
27
-
28
- def forward(self, x, feat):
29
- feat_emb = self.feat_proj(feat)
30
- stacked = torch.stack([x, feat_emb], dim=1)
31
- encoded = self.transformer_encoder(stacked)
32
- aggregated = encoded.mean(dim=1)
33
- return self.regression_head(aggregated)
34
-
35
- # Load model
36
- model = TransformerRegressor()
37
- model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device('cpu')))
38
- model.eval()
39
-
40
- # Feature Functions
41
- descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
42
- Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
43
- Descriptors.NumHDonors, Descriptors.RingCount,
44
- Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
45
- Descriptors.NHOHCount]
46
-
47
- def fix_smiles(s):
48
- try:
49
- mol = Chem.MolFromSmiles(s.strip())
50
- if mol:
51
- return Chem.MolToSmiles(mol)
52
- except:
53
- return None
54
- return None
55
-
56
- def compute_features(smiles):
57
- mol = Chem.MolFromSmiles(smiles)
58
- if not mol:
59
- return [0]*10 + [0]*2048
60
- desc = [fn(mol) for fn in descriptor_fns]
61
- fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
62
- return desc + list(fp)
63
-
64
- def embed_smiles(smiles_list):
65
- inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
66
- outputs = chemberta(**inputs)
67
- return outputs.last_hidden_state[:, 0, :]
68
-
69
- # Streamlit UI
70
- st.set_page_config(page_title="TransPolymer", layout="centered")
71
- st.title("TransPolymer - Predict Polymer Properties")
72
-
73
- smiles_input = st.text_input("Enter SMILES Representation of Polymer")
74
-
75
- if st.button("Predict"):
76
- fixed = fix_smiles(smiles_input)
77
- if not fixed:
78
- st.error("Invalid SMILES string.")
79
  else:
80
- features = compute_features(fixed)
81
- features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
82
- embedding = embed_smiles([fixed])
83
-
84
- with torch.no_grad():
85
- pred = model(embedding, features_tensor)
86
- result = pred.numpy().flatten()
87
 
88
- properties = [
89
- "Tensile Strength",
90
- "Ionization Energy",
91
- "Electron Affinity",
92
- "logP",
93
- "Refractive Index",
94
- "Molecular Weight"
95
- ]
96
 
97
- st.success("Predicted Polymer Properties:")
98
- for prop, val in zip(properties, result):
99
- st.write(f"**{prop}**: {val:.4f}")
 
1
  import streamlit as st
 
 
 
 
 
 
 
2
 
3
+ # Import your page files
4
+ import home
5
+ import prediction
6
+ import about
7
+ import contact
8
+
9
+ # Set up the page configuration
10
+ st.set_page_config(page_title="TransPolymer", layout="wide")
11
+
12
+ # Set up navigation logic
13
+ def load_page(page):
14
+ if page == "Home":
15
+ home.show()
16
+ elif page == "Predictions":
17
+ prediction.show()
18
+ elif page == "About":
19
+ about.show()
20
+ elif page == "Contact":
21
+ contact.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  else:
23
+ st.error("Page not found")
 
 
 
 
 
 
24
 
25
+ # Navigation menu
26
+ st.sidebar.title("Navigation")
27
+ page = st.sidebar.radio("Select a Page", ["Home", "Predictions", "About", "Contact"])
 
 
 
 
 
28
 
29
+ # Call the function to display the selected page
30
+ load_page(page)