import argparse
import resource
from pathlib import Path
import pretrain
import seml
import torch
import vaetrain
from sacred import Experiment
from seml.utils import make_hash
ex = Experiment()
seml.setup_logger(ex)
@ex.post_run_hook
def collect_stats(_run):
    seml.collect_exp_stats(_run)
@ex.config
def config():
    overwrite = None
    db_collection = None
    if db_collection is not None:
        ex.observers.append(
            seml.create_mongodb_observer(db_collection, overwrite=overwrite)
        )
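# NOTE: `db_collection` and `overwrite` default to None for plain local runs; when the script
# is launched through seml, the runner injects both values into the config (an assumption
# based on standard seml usage), which activates the MongoDB observer above.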
class ExperimentWrapper:
    def __init__(self, init_all=True):
        pass
    @ex.capture(prefix="training")
    def train(
        self,
        training_path,
        incl_zinc,
        save_path,
        batch_size,
        hidden_size,
        latent_size,
        depth,
        lr,
        gamma,
        max_epoch,
        num_workers,
        print_iter,
        save_iter,
        subsample_zinc_percent,
        pretrain_only=True,
        multip_share_strategy=None,
        model_path=None,
        beta=0.0,
        vocab_path=None,
    ):
        if multip_share_strategy:
            torch.multiprocessing.set_sharing_strategy(multip_share_strategy)
        # Allow more file descriptors to be open in parallel; DataLoader workers
        # can otherwise exhaust the default soft limit.
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
        # Validate the input path before we try to read from it.
        if training_path:
            assert Path(training_path).exists(), training_path
        # Construct the training file. If requested, also add all SMILES from ZINC.
        outpath = (
            Path.cwd()
            / "data"
            / f"train_{make_hash(ex.current_run.config)}.txt"
        )
        zinc_f = Path.home() / ".dgl" / "jtvae" / "train.txt"
        assert zinc_f.exists(), zinc_f
        # Opening with mode "w" truncates the outfile if it already exists.
        with open(outpath, "w") as outfile:
            n_total_smiles = 0
            if incl_zinc:
                with open(zinc_f) as infile:
                    for i, line in enumerate(infile):
                        # Subsample ZINC: keep only the first fraction of its 220,011 lines.
                        if i >= int(subsample_zinc_percent * 220011):
                            break
                        line = line.strip()
                        # Skip the header and the weird 'Cl.[Li]' drug.
                        if line != "smiles" and "Li" not in line:
                            n_total_smiles += 1
                            outfile.write(line + "\n")
            with open(training_path) as infile:
                for line in infile:
                    line = line.strip()
                    # Skip the header and the weird 'Cl.[Li]' drug.
                    if line != "smiles" and "Li" not in line:
                        n_total_smiles += 1
                        outfile.write(line + "\n")
        print(f"Total SMILES: {n_total_smiles}, stored at {outpath.resolve()}")
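        # `pretrain_only` selects reconstruction-only pretraining via pretrain.main (which
        # takes no `beta`); otherwise vaetrain.main trains the full VAE with KL weight `beta`,
        # optionally warm-starting from `model_path` with the vocabulary at `vocab_path`.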
        if pretrain_only:
            args = argparse.Namespace(
                **{
                    "train_path": str(outpath),
                    "save_path": save_path,
                    "batch_size": batch_size,
                    "hidden_size": hidden_size,
                    "latent_size": latent_size,
                    "depth": depth,
                    "lr": lr,
                    "gamma": gamma,
                    "max_epoch": max_epoch,
                    "num_workers": num_workers,
                    "print_iter": print_iter,
                    "save_iter": save_iter,
                    "use_cpu": False,
                    "hash": make_hash(ex.current_run.config),
                }
            )
            results = pretrain.main(args)
        else:
            args = argparse.Namespace(
                **{
                    "train_path": str(outpath),
                    "save_path": save_path,
                    "vocab_path": vocab_path,
                    "model_path": model_path,
                    "batch_size": batch_size,
                    "hidden_size": hidden_size,
                    "latent_size": latent_size,
                    "depth": depth,
                    "lr": lr,
                    "gamma": gamma,
                    "max_epoch": max_epoch,
                    "num_workers": num_workers,
                    "print_iter": print_iter,
                    "save_iter": save_iter,
                    "use_cpu": False,
                    "hash": make_hash(ex.current_run.config),
                    "beta": beta,
                }
            )
            results = vaetrain.main(args)
        return results
# We can call this command, e.g., from a Jupyter notebook with init_all=False to get an "empty"
# experiment wrapper, in which we can then, for instance, load a pretrained model to inspect
# its performance.
@ex.command(unobserved=True)
def get_experiment(init_all=False):
    print("get_experiment")
    experiment = ExperimentWrapper(init_all=init_all)
    return experiment
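# A minimal sketch of that notebook workflow (hypothetical session; the module name and
# checkpoint path are placeholders, not values taken from this repo):
#
#   from experiment_script import ex              # assuming this file is experiment_script.py
#   wrapper = ex.run("get_experiment").result     # Sacred's Run object exposes the return value
#   model = torch.load("checkpoints/model_pre.pt")  # e.g., load a pretrained model to inspect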
# This function will be called by default. Note that we could in principle manually pass an
# experiment instance, e.g., obtained by loading a model from the database or by calling this
# from a Jupyter notebook.
@ex.automain
def train(experiment=None):
    if experiment is None:
        experiment = ExperimentWrapper()
    return experiment.train()
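# A hedged example of how such a script is typically launched (assuming a standard seml setup;
# the collection name and YAML file below are placeholders):
#
#   seml jtvae_training add config.yaml   # queue the experiments defined in config.yaml
#   seml jtvae_training start             # submit the queued experiments for execution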