transpolymer committed on
Commit
84dad8f
·
verified ·
1 Parent(s): ff928a7

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +126 -73
prediction.py CHANGED
@@ -1,73 +1,126 @@
1
- import streamlit as st
2
- import requests
3
- from rdkit import Chem
4
- import datetime
5
- from db import get_database
6
-
7
- # Function to validate SMILES string
8
- def is_valid_smiles(smiles):
9
- """ Validate if the input is a valid SMILES string using RDKit """
10
- mol = Chem.MolFromSmiles(smiles)
11
- return mol is not None
12
-
13
- # Function to save prediction to MongoDB
14
- def save_to_db(smiles_input, predictions):
15
- db = get_database()
16
- collection = db["polymers"] # your collection
17
- doc = {
18
- "smiles": smiles_input,
19
- "predictions": predictions,
20
- "timestamp": datetime.datetime.utcnow()
21
- }
22
- collection.insert_one(doc)
23
-
24
- # Streamlit page for Polymer Property Prediction
25
- def show():
26
- st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
27
- st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)
28
-
29
- # Input box with placeholder inside
30
- input_text = st.text_input(
31
- label="",
32
- placeholder="Enter SMILES ",
33
- key="smiles_input"
34
- )
35
-
36
- # Show Predict button always
37
- predict_clicked = st.button("🚀 Predict", use_container_width=True)
38
-
39
- # Predict on button click OR on pressing Enter with input
40
- if (predict_clicked or input_text) and input_text.strip():
41
- with st.spinner("Predicting..."):
42
- if is_valid_smiles(input_text.strip()):
43
- try:
44
- input_data = {
45
- "smiles": input_text.strip() # Only sending smiles
46
- }
47
- response = requests.post("http://127.0.0.1:8000/predict", json=input_data)
48
- if response.status_code == 200:
49
- result = response.json()
50
- renamed_properties = {
51
- "property1": "Tensile_strength (Mpa)",
52
- "property2": "Ionization_Energy (eV)",
53
- "property3": "Electron_Affinity (eV)",
54
- "property4": "LogP",
55
- "property5": "Refractive_Index",
56
- "property6": "Molecular_Weight (g/mol)"
57
- }
58
- predictions = {}
59
- for key, name in renamed_properties.items():
60
- value = result.get(key, 'N/A')
61
- st.markdown(f"<div style='font-size:18px; padding: 6px 0;'><strong>{name}:</strong> {value}</div>", unsafe_allow_html=True)
62
- predictions[name] = value
63
-
64
- # Save prediction to MongoDB
65
- save_to_db(input_text.strip(), predictions)
66
-
67
- st.success("Prediction saved successfully!")
68
- else:
69
- st.error("Prediction failed. Please try again.")
70
- except Exception as e:
71
- st.error(f"Error: {e}")
72
- else:
73
- st.error("❌ Invalid SMILES input. Please enter a correct SMILES string.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem, Descriptors
7
+ from torch import nn
8
+ import pandas as pd
9
+ import requests
10
+ import datetime
11
+ from db import get_database # Assuming you have a file db.py with get_database function to connect to MongoDB
12
+
13
+ # Model Setup
14
# Single checkpoint name shared by the tokenizer and the encoder so the two
# can never drift apart.
_CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"

# Both objects are created once, at import time, and reused by embed_smiles().
tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_CHECKPOINT)
chemberta = AutoModel.from_pretrained(_CHEMBERTA_CHECKPOINT)
# Put the encoder in evaluation mode; it is only used for inference here.
chemberta.eval()
17
+
18
+ # Define your model architecture
19
class TransformerRegressor(nn.Module):
    """Regress several properties from a SMILES embedding plus a feature vector.

    The raw feature vector is linearly projected to the embedding width,
    stacked with the SMILES embedding as a two-token sequence, run through a
    Transformer encoder, mean-pooled over the two tokens and passed to an MLP
    regression head.

    Args:
        emb_dim: Width of the SMILES embedding; also the encoder's ``d_model``.
        feat_dim: Length of the raw feature vector passed to ``forward``.
        output_dim: Number of regression targets.
        nhead: Attention heads per encoder layer (must divide ``emb_dim``).
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward net.
        dropout: Dropout probability used inside the encoder layers.

    Note:
        The submodule attribute names (``feat_proj``, ``transformer_encoder``,
        ``regression_head``) become the state-dict key prefixes, so renaming
        them would break loading of existing checkpoints.
    """

    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8,
                 num_layers=2, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        # The encoder width must follow emb_dim: forward() stacks the projected
        # features with the embedding, so both tokens are emb_dim wide.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        """Return predictions of shape ``(batch, output_dim)``.

        Args:
            x: SMILES embeddings, expected shape ``(batch, emb_dim)``.
            feat: Raw feature vectors, expected shape ``(batch, feat_dim)``.
        """
        feat_emb = self.feat_proj(feat)
        # Two "tokens" per sample: (batch, 2, emb_dim).
        stacked = torch.stack([x, feat_emb], dim=1)
        encoded = self.transformer_encoder(stacked)
        # Average the two token representations into one vector per sample.
        aggregated = encoded.mean(dim=1)
        return self.regression_head(aggregated)
37
+
38
+ # Load model
39
# Instantiate the regressor with its default hyper-parameters in evaluation
# mode, then restore the trained weights onto the CPU.
model = TransformerRegressor().eval()
model.load_state_dict(
    torch.load("transformer_model.pt", map_location=torch.device('cpu'))
)
42
+
43
+ # Feature Functions
44
# RDKit descriptor callables evaluated on every molecule.
# Order is significant: compute_features() emits one value per entry, in this
# order, ahead of the fingerprint bits.
descriptor_fns = [
    Descriptors.MolWt,
    Descriptors.MolLogP,
    Descriptors.TPSA,
    Descriptors.NumRotatableBonds,
    Descriptors.NumHAcceptors,
    Descriptors.NumHDonors,
    Descriptors.RingCount,
    Descriptors.FractionCSP3,
    Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]
49
+
50
def fix_smiles(s):
    """Return the RDKit-canonical form of a SMILES string, or None.

    Args:
        s: Candidate SMILES text; surrounding whitespace is ignored.

    Returns:
        The SMILES string re-serialised by RDKit, or ``None`` when ``s`` is
        not a string, is blank, or cannot be parsed.
    """
    # Reject non-text and blank input up front instead of relying on an
    # exception from .strip() or on RDKit's handling of an empty string.
    if not isinstance(s, str):
        return None
    candidate = s.strip()
    if not candidate:
        return None

    try:
        mol = Chem.MolFromSmiles(candidate)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol)
    except Exception:
        # Best-effort helper: any RDKit failure means "not usable".
        # Unlike a bare ``except:``, KeyboardInterrupt and SystemExit still
        # propagate.
        return None
58
+
59
def compute_features(smiles, radius=2, n_bits=2048):
    """Build the descriptor + Morgan-fingerprint feature list for a molecule.

    Args:
        smiles: SMILES string to featurise.
        radius: Morgan fingerprint radius.
        n_bits: Number of fingerprint bits.

    Returns:
        A list of ``len(descriptor_fns) + n_bits`` numbers: one value per
        entry of ``descriptor_fns`` followed by the fingerprint bits. If the
        SMILES cannot be parsed, a list of zeros of that same length.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Derive the length from descriptor_fns so the zero vector stays in
        # sync if descriptors are added or removed.
        return [0] * (len(descriptor_fns) + n_bits)

    desc = [fn(mol) for fn in descriptor_fns]
    # NOTE(review): newer RDKit releases flag this call as deprecated in
    # favour of rdFingerprintGenerator. It is deliberately kept because the
    # trained model presumably expects exactly these bits -- confirm before
    # switching.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
    return desc + list(fp)
66
+
67
def embed_smiles(smiles_list):
    """Embed SMILES strings with the module-level ChemBERTa encoder.

    Args:
        smiles_list: List of SMILES strings; inputs are padded to a common
            length and truncated to 128 tokens.

    Returns:
        A tensor with one row per input string, holding the encoder's last
        hidden state at the first token position.
    """
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for each request.
    with torch.no_grad():
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
71
+
72
+ # Function to validate SMILES string
73
def is_valid_smiles(smiles):
    """Return True when RDKit parses ``smiles`` into a molecule, else False."""
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return False
    return True
77
+
78
+ # Function to save prediction to MongoDB
79
def save_to_db(smiles_input, predictions):
    """Insert one prediction record into the ``polymers`` collection.

    Args:
        smiles_input: SMILES string the prediction was made for.
        predictions: Mapping of property name to predicted value.

    The stored document has the keys ``smiles``, ``predictions`` and
    ``timestamp``.
    """
    collection = get_database()["polymers"]
    collection.insert_one(
        {
            "smiles": smiles_input,
            "predictions": predictions,
            # Timezone-aware UTC "now". datetime.utcnow() returns a naive
            # value and is deprecated since Python 3.12. The driver is
            # expected to store the same UTC instant -- confirm against the
            # PyMongo datetime documentation.
            "timestamp": datetime.datetime.now(datetime.timezone.utc),
        }
    )
88
+
89
+ # Prediction Page UI
90
def show():
    """Render the Streamlit prediction page.

    Reads a SMILES string, and on "Predict" canonicalises it, builds the
    feature vector and ChemBERTa embedding, runs the regressor, displays the
    six predicted properties and stores them via ``save_to_db``.
    """
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if not st.button("Predict"):
        return

    fixed = fix_smiles(smiles_input)
    if not fixed:
        st.error("Invalid SMILES string.")
        return

    features = compute_features(fixed)
    features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)

    # Run both the encoder and the regressor without autograd bookkeeping.
    with torch.no_grad():
        embedding = embed_smiles([fixed])
        pred = model(embedding, features_tensor)
    result = pred.numpy().flatten()

    # Presumably the order of the model's six outputs -- confirm against the
    # training code before changing.
    properties = [
        "Tensile Strength",
        "Ionization Energy",
        "Electron Affinity",
        "logP",
        "Refractive Index",
        "Molecular Weight",
    ]

    predictions = {}
    st.success("Predicted Polymer Properties:")
    for prop, val in zip(properties, result):
        # Convert the NumPy float32 scalar to a plain Python float: BSON
        # cannot encode numpy.float32, so storing it raw makes the save fail.
        value = float(val)
        st.write(f"**{prop}**: {value:.4f}")
        predictions[prop] = value

    # Save the prediction to MongoDB. A database outage must not take down
    # the page after the results have already been shown.
    try:
        save_to_db(smiles_input.strip(), predictions)
    except Exception as exc:
        st.error(f"Prediction could not be saved: {exc}")
    else:
        st.success("Prediction saved successfully!")