Spaces:
Running
Running
Update prediction.py
Browse files - prediction.py (+17 -4)
prediction.py
CHANGED
@@ -6,6 +6,7 @@ import joblib
|
|
6 |
from transformers import AutoTokenizer, AutoModel
|
7 |
from rdkit import Chem
|
8 |
from rdkit.Chem import Descriptors
|
|
|
9 |
from datetime import datetime
|
10 |
from db import get_database # This must be available in your repo
|
11 |
|
@@ -86,12 +87,20 @@ def compute_descriptors(smiles: str):
|
|
86 |
]
|
87 |
return np.array(descriptors, dtype=np.float32)
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
# Embedding function
|
90 |
def get_chemberta_embedding(smiles: str):
    """Return the first-position hidden state of ChemBERTa for *smiles*.

    The string is tokenized with the module-level ``tokenizer`` (PyTorch
    tensors requested via ``return_tensors="pt"``) and passed through the
    module-level ``chemberta`` model with gradient tracking disabled.
    """
    encoded = tokenizer(smiles, return_tensors="pt")
    with torch.no_grad():
        model_out = chemberta(**encoded)
        hidden_states = model_out.last_hidden_state
        # Index 0 along the sequence axis is the CLS token.
        return hidden_states[:, 0, :]
|
95 |
|
96 |
# Save prediction to MongoDB
|
97 |
def save_to_db(smiles, predictions):
|
@@ -119,11 +128,15 @@ def show():
|
|
119 |
return
|
120 |
|
121 |
descriptors = compute_descriptors(smiles_input)
|
122 |
-
descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
|
|
|
|
|
|
|
123 |
|
124 |
-
embedding = get_chemberta_embedding(smiles_input)
|
125 |
|
126 |
-
|
|
|
127 |
|
128 |
with torch.no_grad():
|
129 |
preds = model(combined)
|
|
|
6 |
from transformers import AutoTokenizer, AutoModel
|
7 |
from rdkit import Chem
|
8 |
from rdkit.Chem import Descriptors
|
9 |
+
from rdkit.Chem import AllChem
|
10 |
from datetime import datetime
|
11 |
from db import get_database # This must be available in your repo
|
12 |
|
|
|
87 |
]
|
88 |
return np.array(descriptors, dtype=np.float32)
|
89 |
|
90 |
+
# Fingerprint computation
|
91 |
+
def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
    """Return the Morgan bit-vector fingerprint of *smiles* as a float32 row.

    Args:
        smiles: SMILES string parsed with ``Chem.MolFromSmiles``.
        radius: Morgan radius forwarded to RDKit.
        n_bits: Bit-vector length, forwarded to RDKit as ``nBits``.

    Returns:
        A ``np.float32`` array reshaped to a single row, i.e. ``(1, 1280)``
        with the default ``n_bits``.

    Raises:
        ValueError: If RDKit cannot parse *smiles* (``MolFromSmiles``
            returned ``None``).
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES string.")

    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
    bits = np.array(bit_vector, dtype=np.float32)
    # Add a leading axis so the result is one row: (1, 1280) by default.
    return bits.reshape(1, -1)
|
97 |
+
|
98 |
# Embedding function
|
99 |
def get_chemberta_embedding(smiles: str):
    """Embed *smiles* with the module-level ChemBERTa model.

    Tokenizes the string into PyTorch tensors, runs ``chemberta`` with
    gradient tracking disabled, and returns position 0 of
    ``last_hidden_state`` for every batch entry.
    """
    token_batch = tokenizer(smiles, return_tensors="pt")
    with torch.no_grad():
        result = chemberta(**token_batch)
        last_hidden = result.last_hidden_state
        return last_hidden[:, 0, :]  # CLS token (1, 768)
|
104 |
|
105 |
# Save prediction to MongoDB
|
106 |
def save_to_db(smiles, predictions):
|
|
|
128 |
return
|
129 |
|
130 |
descriptors = compute_descriptors(smiles_input)
|
131 |
+
descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0) # (1, 10)
|
132 |
+
|
133 |
+
fingerprint = get_morgan_fingerprint(smiles_input) # (1, 1280)
|
134 |
+
fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32) # (1, 1280)
|
135 |
|
136 |
+
embedding = get_chemberta_embedding(smiles_input) # (1, 768)
|
137 |
|
138 |
+
combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1) # (1, 2058)
|
139 |
+
combined = combined_input.unsqueeze(1) # (1, 1, 2058)
|
140 |
|
141 |
with torch.no_grad():
|
142 |
preds = model(combined)
|