import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
from datetime import datetime
from db import get_database  # db.py must provide get_database() returning a MongoDB database handle
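# A minimal sketch of what db.py might contain (an assumption; the connection
# string and database name are hypothetical), using pymongo:
#
#   from pymongo import MongoClient
#
#   def get_database():
#       client = MongoClient("mongodb://localhost:27017")
#       return client["polymer_db"]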

# Load tokenizer and ChemBERTa model
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model
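
# st.cache_resource keeps a single shared instance across Streamlit reruns,
# so the tokenizer and ChemBERTa weights load once per server process.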

tokenizer, chemberta = load_chemberta()

# Regression model: fuses the ChemBERTa [CLS] embedding with projected RDKit
# features via a transformer encoder, then regresses six polymer properties
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=1024, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)
        stacked = torch.stack([x, feat_emb], dim=1)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)
        return self.regression_head(aggregated)
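
# Shape sketch (illustrative, assuming the defaults above): x is the ChemBERTa
# [CLS] embedding (batch, 768) and feat is (batch, 2058). feat is projected to
# 768 dims, stacked with x into a length-2 sequence (batch, 2, 768), encoded,
# mean-pooled back to (batch, 768), and regressed to 6 outputs. Quick check:
#
#   m = TransformerRegressor()
#   out = m(torch.randn(4, 768), torch.randn(4, 2058))
#   assert out.shape == (4, 6)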

# Load the saved regression model (expects transformer_model.pt in the working directory)
@st.cache_resource
def load_regression_model():
    model = TransformerRegressor()
    state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_regression_model()

# Feature Functions
descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
                  Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
                  Descriptors.NumHDonors, Descriptors.RingCount,
                  Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
                  Descriptors.NHOHCount]
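
# These 10 descriptors plus the 2048-bit Morgan fingerprint computed below
# give the 2058-dim vector TransformerRegressor expects as feat_dim.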

def fix_smiles(s):
    """Canonicalise a SMILES string; returns None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        pass
    return None
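# Example: fix_smiles(" C1=CC=CC=C1 ") should return "c1ccccc1" (canonical benzene).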

def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0]*10 + [0]*2048  # zero-vector fallback for unparsable SMILES
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)
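# Sanity check: len(compute_features("CCO")) == 2058 (10 descriptors + 2048 bits).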

def embed_smiles(smiles_list):
    """Return the ChemBERTa [CLS] embedding, shape (batch, 768), for each SMILES."""
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only; avoid building an autograd graph
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
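# Design note: the position-0 ([CLS]) hidden state serves as a whole-molecule
# embedding; mean-pooling over all tokens is a common alternative.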

# Function to save prediction to MongoDB
def save_to_db(smiles, predictions):
    # Convert numpy float32 values to native Python floats; BSON cannot
    # encode numpy scalar types, so insert_one would fail otherwise
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now()
    }

    db = get_database()  # Connect to MongoDB
    collection = db["polymer_predictions"]
    collection.insert_one(doc)
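# Resulting document shape (illustrative):
#   {"smiles": "<SMILES>", "predictions": {"Tensile Strength": <float>, ...},
#    "timestamp": <datetime>}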

# Prediction Page UI
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        fixed = fix_smiles(smiles_input)
        if not fixed:
            st.error("Invalid SMILES string.")
        else:
            features = compute_features(fixed)
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
            embedding = embed_smiles([fixed])

            with torch.no_grad():
                pred = model(embedding, features_tensor)
                result = pred.numpy().flatten()

            properties = [
                "Tensile Strength",
                "Ionization Energy",
                "Electron Affinity",
                "logP",
                "Refractive Index",
                "Molecular Weight"
            ]

            predictions = {}
            st.success("Predicted Polymer Properties:")
            for prop, val in zip(properties, result):
                st.write(f"**{prop}**: {val:.4f}")
                predictions[prop] = val

            # Persist the canonical SMILES (the form the model actually saw) to MongoDB
            save_to_db(fixed, predictions)
            st.success("Prediction saved successfully!")