Update inference.py
inference.py  CHANGED  (+45 -22)

@@ -136,7 +136,7 @@ class Inference(object):
         """Restore the trained generator and discriminator."""
         print('Loading the model...')
         G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
-        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage
+        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

     def inference(self):
         # Load the trained generator.
@@ -170,7 +170,9 @@ class Inference(object):
         uniqueness_calc = []
         real_smiles_snn = []
         nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
-
+        f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
+        f.write("SMILES")
+        f.write("\n")
         val_counter = 0
         none_counter = 0

@@ -179,6 +181,7 @@ class Inference(object):
         pbar = tqdm(range(self.sample_num))
         pbar.set_description('Inference mode for {} model started'.format(self.submodel))
         for i, data in enumerate(self.inf_loader):
+
             val_counter += 1
             # Preprocess dataset
             _, a_tensor, x_tensor = load_molecules(
@@ -206,13 +209,14 @@ class Inference(object):
             inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]

             for molecules in inference_drugs:
-
-
+                if molecules is None:
+                    none_counter += 1

             for molecules in inference_drugs:
                 if molecules is not None:
-                    molecules = molecules.replace("*", "C")
-
+                    molecules = molecules.replace("*", "C")
+                    f.write(molecules)
+                    f.write("\n")
                     uniqueness_calc.append(molecules)
                     nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
             pbar.update(1)
@@ -223,21 +227,30 @@ class Inference(object):
             if generation_number == self.sample_num or none_counter == self.sample_num:
                 break

+        f.close()
+        print("Inference completed, starting metrics calculation.")
         if not self.disable_correction:
-
-            gen_smi =
+            corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
+            gen_smi = corrected["SMILES"].tolist()
+
         else:
-            gen_smi =
-
+            gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
+
+
         et = time.time() - start_time

         gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
         real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
+        print("Inference mode is lasted for {:.2f} seconds".format(et))

+        print("Metrics calculation started using MOSES.")
+
         if not self.disable_correction:
             val = round(len(gen_smi)/self.sample_num, 3)
+            print("Validity: ", val, "\n")
         else:
             val = round(fraction_valid(gen_smi), 3)
+            print("Validity: ", val, "\n")

         uniq = round(fraction_unique(gen_smi), 3)
         nov = round(novelty(gen_smi, chembl_smiles), 3)
@@ -251,23 +264,33 @@ class Inference(object):
         qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
         sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)

+        print("Uniqueness: ", uniq, "\n")
+        print("Novelty: ", nov, "\n")
+        print("Novelty_test: ", nov_test, "\n")
+        print("Drug_novelty: ", drug_nov, "\n")
+        print("max_len: ", max_len, "\n")
+        print("mean_atom_type: ", mean_atom, "\n")
+        print("snn_chembl: ", snn_chembl, "\n")
+        print("snn_drug: ", snn_drug, "\n")
+        print("IntDiv: ", int_div, "\n")
+        print("QED: ", qed, "\n")
+        print("SA: ", sa, "\n")
+
+        print("Metrics are calculated.")
         model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
                                   "uniqueness": [uniq], "novelty": [nov],
                                   "novelty_test": [nov_test], "drug_novelty": [drug_nov],
                                   "max_len": [max_len], "mean_atom_type": [mean_atom],
                                   "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
                                   "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
-
-
-
-
-
-
-
-
-        return model_res
-
-
+        search_res = pd.concat([search_res, model_res], axis=0)
+        os.remove("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
+        search_res.to_csv("experiments/inference/{}/inference_results.csv".format(self.submodel), index=False)
+        generatedsmiles = pd.DataFrame({"SMILES": gen_smi})
+        generatedsmiles.to_csv("experiments/inference/{}/inference_drugs.csv".format(self.submodel), index=False)
+
+        return model_res
+
 if __name__=="__main__":
     parser = argparse.ArgumentParser()

@@ -300,4 +323,4 @@ if __name__=="__main__":

     config = parser.parse_args()
     inference = Inference(config)
-    inference.inference()
+    inference.inference()
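Note on the first hunk: the change only closes the previously unbalanced torch.load(...) call. For reference, a minimal sketch of the standard PyTorch checkpoint-restore pattern used there, with a placeholder model and file name that are not part of this repository:

# Minimal sketch: save and restore a generator's weights on CPU.
# SimpleGenerator and "generator.ckpt" are placeholders for illustration only.
import torch
import torch.nn as nn

class SimpleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)

G = SimpleGenerator()
torch.save(G.state_dict(), "generator.ckpt")

# map_location=lambda storage, loc: storage keeps deserialized tensors on CPU,
# so a checkpoint trained on GPU still loads on a CPU-only machine.
state = torch.load("generator.ckpt", map_location=lambda storage, loc: storage)
G.load_state_dict(state)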
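The metrics block relies on MOSES (fraction_valid, fraction_unique, novelty) and RDKit (Morgan fingerprints, QED) helpers. Below is a minimal, self-contained sketch of those calls using placeholder SMILES lists rather than model output; the import path moses.metrics.metrics is an assumption about how these helpers are packaged:

# Minimal sketch of the MOSES/RDKit metric calls used in the diff.
# The SMILES lists are placeholders for illustration, not generated molecules.
from moses.metrics.metrics import fraction_valid, fraction_unique, novelty
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, QED

gen_smi = ["CCO", "c1ccccc1", "CC(=O)O"]   # "generated" set (placeholder)
ref_smi = ["CCO", "CCN"]                   # reference set (placeholder)

val = round(fraction_valid(gen_smi), 3)    # fraction of parseable SMILES
uniq = round(fraction_unique(gen_smi), 3)  # fraction of unique canonical SMILES
nov = round(novelty(gen_smi, ref_smi), 3)  # fraction of generated molecules absent from ref_smi

# Morgan fingerprints (radius 2, 1024 bits) and a pairwise Tanimoto similarity,
# the building blocks of the SNN-style comparisons in the diff.
mols = [m for m in (Chem.MolFromSmiles(s) for s in gen_smi) if m is not None]
fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in mols]
tanimoto = DataStructs.TanimotoSimilarity(fps[0], fps[1])

qed = round(sum(QED.qed(m) for m in mols) / len(mols), 3)  # mean drug-likeness

print(val, uniq, nov, tanimoto, qed)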