transpolymer committed on
Commit
af4f3b0
·
verified ·
1 Parent(s): 1aed2eb

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +17 -4
prediction.py CHANGED
@@ -6,6 +6,7 @@ import joblib
6
  from transformers import AutoTokenizer, AutoModel
7
  from rdkit import Chem
8
  from rdkit.Chem import Descriptors
 
9
  from datetime import datetime
10
  from db import get_database # This must be available in your repo
11
 
@@ -86,12 +87,20 @@ def compute_descriptors(smiles: str):
86
  ]
87
  return np.array(descriptors, dtype=np.float32)
88
 
 
 
 
 
 
 
 
 
89
  # Embedding function
90
  def get_chemberta_embedding(smiles: str):
91
  inputs = tokenizer(smiles, return_tensors="pt")
92
  with torch.no_grad():
93
  outputs = chemberta(**inputs)
94
- return outputs.last_hidden_state[:, 0, :] # CLS token
95
 
96
  # Save prediction to MongoDB
97
  def save_to_db(smiles, predictions):
@@ -119,11 +128,15 @@ def show():
119
  return
120
 
121
  descriptors = compute_descriptors(smiles_input)
122
- descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
 
 
 
123
 
124
- embedding = get_chemberta_embedding(smiles_input)
125
 
126
- combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1) # (1, 1, 2058)
 
127
 
128
  with torch.no_grad():
129
  preds = model(combined)
 
6
  from transformers import AutoTokenizer, AutoModel
7
  from rdkit import Chem
8
  from rdkit.Chem import Descriptors
9
+ from rdkit.Chem import AllChem
10
  from datetime import datetime
11
  from db import get_database # This must be available in your repo
12
 
 
87
  ]
88
  return np.array(descriptors, dtype=np.float32)
89
 
90
+ # Fingerprint computation
91
+ def get_morgan_fingerprint(smiles, radius=2, n_bits=1280):
92
+ mol = Chem.MolFromSmiles(smiles)
93
+ if mol is None:
94
+ raise ValueError("Invalid SMILES string.")
95
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
96
+ return np.array(fp, dtype=np.float32).reshape(1, -1) # (1, 1280)
97
+
98
  # Embedding function
99
  def get_chemberta_embedding(smiles: str):
100
  inputs = tokenizer(smiles, return_tensors="pt")
101
  with torch.no_grad():
102
  outputs = chemberta(**inputs)
103
+ return outputs.last_hidden_state[:, 0, :] # CLS token (1, 768)
104
 
105
  # Save prediction to MongoDB
106
  def save_to_db(smiles, predictions):
 
128
  return
129
 
130
  descriptors = compute_descriptors(smiles_input)
131
+ descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0) # (1, 10)
132
+
133
+ fingerprint = get_morgan_fingerprint(smiles_input) # (1, 1280)
134
+ fingerprint_tensor = torch.tensor(fingerprint, dtype=torch.float32) # (1, 1280)
135
 
136
+ embedding = get_chemberta_embedding(smiles_input) # (1, 768)
137
 
138
+ combined_input = torch.cat([embedding, descriptors_tensor, fingerprint_tensor], dim=1) # (1, 2058)
139
+ combined = combined_input.unsqueeze(1) # (1, 1, 2058)
140
 
141
  with torch.no_grad():
142
  preds = model(combined)