Spaces:
Running
Running
Update prediction.py
Browse files- prediction.py +8 -20
prediction.py
CHANGED
@@ -8,10 +8,9 @@ from rdkit import Chem
|
|
8 |
from rdkit.Chem import Descriptors
|
9 |
from rdkit.Chem import AllChem
|
10 |
from datetime import datetime
|
11 |
-
from db import get_database
|
12 |
import random
|
13 |
-
|
14 |
-
|
15 |
|
16 |
# ------------------------ Ensuring Deterministic Behavior ------------------------
|
17 |
random.seed(42)
|
@@ -20,7 +19,6 @@ torch.manual_seed(42)
|
|
20 |
torch.backends.cudnn.deterministic = True
|
21 |
torch.backends.cudnn.benchmark = False
|
22 |
|
23 |
-
# Check if CUDA is available for GPU acceleration
|
24 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
|
26 |
# ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
|
@@ -29,7 +27,7 @@ def load_chemberta():
|
|
29 |
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
30 |
model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
31 |
model.eval()
|
32 |
-
model.to(device)
|
33 |
return tokenizer, model
|
34 |
|
35 |
# ------------------------ Load Scalers ------------------------
|
@@ -51,7 +49,7 @@ class TransformerRegressor(nn.Module):
|
|
51 |
d_model=embedding_dim,
|
52 |
nhead=8,
|
53 |
dim_feedforward=ff_dim,
|
54 |
-
dropout=0.0,
|
55 |
batch_first=True
|
56 |
)
|
57 |
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
@@ -72,18 +70,15 @@ class TransformerRegressor(nn.Module):
|
|
72 |
# ------------------------ Load Model ------------------------
|
73 |
@st.cache_resource
|
74 |
def load_model():
|
75 |
-
# Initialize the model architecture first
|
76 |
model = TransformerRegressor()
|
77 |
-
|
78 |
-
# Load the state_dict (weights) from the saved model file
|
79 |
try:
|
80 |
-
|
|
|
81 |
model.load_state_dict(state_dict)
|
82 |
model.eval()
|
83 |
-
model.to(device)
|
84 |
except Exception as e:
|
85 |
raise ValueError(f"Failed to load model: {e}")
|
86 |
-
|
87 |
return model
|
88 |
|
89 |
# ------------------------ Descriptors ------------------------
|
@@ -91,7 +86,6 @@ def compute_descriptors(smiles: str):
|
|
91 |
mol = Chem.MolFromSmiles(smiles)
|
92 |
if mol is None:
|
93 |
raise ValueError("Invalid SMILES string.")
|
94 |
-
|
95 |
descriptors = [
|
96 |
Descriptors.MolWt(mol),
|
97 |
Descriptors.MolLogP(mol),
|
@@ -119,7 +113,7 @@ def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
|
|
119 |
inputs = tokenizer(smiles, return_tensors="pt")
|
120 |
with torch.no_grad():
|
121 |
outputs = chemberta(**inputs)
|
122 |
-
return outputs.last_hidden_state.mean(dim=1)
|
123 |
|
124 |
# ------------------------ Save to DB ------------------------
|
125 |
def save_to_db(smiles, predictions):
|
@@ -141,7 +135,6 @@ def show():
|
|
141 |
|
142 |
if st.button("Predict"):
|
143 |
try:
|
144 |
-
# Load the model
|
145 |
model = load_model()
|
146 |
|
147 |
mol = Chem.MolFromSmiles(smiles_input)
|
@@ -149,10 +142,8 @@ def show():
|
|
149 |
st.error("Invalid SMILES string.")
|
150 |
return
|
151 |
|
152 |
-
# Load the ChemBERTa tokenizer and model
|
153 |
tokenizer, chemberta = load_chemberta()
|
154 |
|
155 |
-
# Compute Descriptors, Fingerprints, and Embedding
|
156 |
descriptors = compute_descriptors(smiles_input)
|
157 |
descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
|
158 |
|
@@ -161,7 +152,6 @@ def show():
|
|
161 |
|
162 |
embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)
|
163 |
|
164 |
-
# Combine Inputs and Make Prediction
|
165 |
combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
|
166 |
combined = combined_input.unsqueeze(1)
|
167 |
|
@@ -171,7 +161,6 @@ def show():
|
|
171 |
preds_np = preds.numpy()
|
172 |
keys = list(scalers.keys())
|
173 |
|
174 |
-
# Rescale Predictions
|
175 |
preds_rescaled = np.concatenate([
|
176 |
scalers[keys[i]].inverse_transform(preds_np[:, [i]])
|
177 |
for i in range(6)
|
@@ -183,7 +172,6 @@ def show():
|
|
183 |
for key, val in results.items():
|
184 |
st.markdown(f"**{key}**: {val}")
|
185 |
|
186 |
-
# Save the results to the database
|
187 |
save_to_db(smiles_input, results)
|
188 |
|
189 |
except Exception as e:
|
|
|
8 |
from rdkit.Chem import Descriptors
|
9 |
from rdkit.Chem import AllChem
|
10 |
from datetime import datetime
|
11 |
+
from db import get_database
|
12 |
import random
|
13 |
+
import os # <-- Added for debugging file paths
|
|
|
14 |
|
15 |
# ------------------------ Ensuring Deterministic Behavior ------------------------
|
16 |
random.seed(42)
|
|
|
19 |
torch.backends.cudnn.deterministic = True
|
20 |
torch.backends.cudnn.benchmark = False
|
21 |
|
|
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
|
24 |
# ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
|
|
|
27 |
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
28 |
model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
29 |
model.eval()
|
30 |
+
model.to(device)
|
31 |
return tokenizer, model
|
32 |
|
33 |
# ------------------------ Load Scalers ------------------------
|
|
|
49 |
d_model=embedding_dim,
|
50 |
nhead=8,
|
51 |
dim_feedforward=ff_dim,
|
52 |
+
dropout=0.0,
|
53 |
batch_first=True
|
54 |
)
|
55 |
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
|
|
70 |
# ------------------------ Load Model ------------------------
|
71 |
@st.cache_resource
|
72 |
def load_model():
|
|
|
73 |
model = TransformerRegressor()
|
|
|
|
|
74 |
try:
|
75 |
+
print("Files in working directory:", os.listdir()) # <-- Debug print
|
76 |
+
state_dict = torch.load("transformer_model.bin", map_location=device)
|
77 |
model.load_state_dict(state_dict)
|
78 |
model.eval()
|
79 |
+
model.to(device)
|
80 |
except Exception as e:
|
81 |
raise ValueError(f"Failed to load model: {e}")
|
|
|
82 |
return model
|
83 |
|
84 |
# ------------------------ Descriptors ------------------------
|
|
|
86 |
mol = Chem.MolFromSmiles(smiles)
|
87 |
if mol is None:
|
88 |
raise ValueError("Invalid SMILES string.")
|
|
|
89 |
descriptors = [
|
90 |
Descriptors.MolWt(mol),
|
91 |
Descriptors.MolLogP(mol),
|
|
|
113 |
inputs = tokenizer(smiles, return_tensors="pt")
|
114 |
with torch.no_grad():
|
115 |
outputs = chemberta(**inputs)
|
116 |
+
return outputs.last_hidden_state.mean(dim=1)
|
117 |
|
118 |
# ------------------------ Save to DB ------------------------
|
119 |
def save_to_db(smiles, predictions):
|
|
|
135 |
|
136 |
if st.button("Predict"):
|
137 |
try:
|
|
|
138 |
model = load_model()
|
139 |
|
140 |
mol = Chem.MolFromSmiles(smiles_input)
|
|
|
142 |
st.error("Invalid SMILES string.")
|
143 |
return
|
144 |
|
|
|
145 |
tokenizer, chemberta = load_chemberta()
|
146 |
|
|
|
147 |
descriptors = compute_descriptors(smiles_input)
|
148 |
descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
|
149 |
|
|
|
152 |
|
153 |
embedding = get_chemberta_embedding(smiles_input, tokenizer, chemberta)
|
154 |
|
|
|
155 |
combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1)
|
156 |
combined = combined_input.unsqueeze(1)
|
157 |
|
|
|
161 |
preds_np = preds.numpy()
|
162 |
keys = list(scalers.keys())
|
163 |
|
|
|
164 |
preds_rescaled = np.concatenate([
|
165 |
scalers[keys[i]].inverse_transform(preds_np[:, [i]])
|
166 |
for i in range(6)
|
|
|
172 |
for key, val in results.items():
|
173 |
st.markdown(f"**{key}**: {val}")
|
174 |
|
|
|
175 |
save_to_db(smiles_input, results)
|
176 |
|
177 |
except Exception as e:
|