"""seml/sacred experiment script for JT-VAE training.

Builds a combined SMILES training file (optionally a subsample of ZINC plus a
user-provided set) and dispatches to either the pretraining or the VAE
training entry point.
"""
import argparse
import resource
from pathlib import Path

import pretrain
import seml
import torch
import vaetrain
from sacred import Experiment
from seml.utils import make_hash

ex = Experiment()
seml.setup_logger(ex)

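# Record resource statistics (runtime, memory, GPU usage) in the experiment's
# database entry once the run finishes.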
|
@ex.post_run_hook
def collect_stats(_run):
    seml.collect_exp_stats(_run)

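# Standard seml config hook: when the experiment is launched through seml,
# db_collection is injected into the config and a MongoDB observer is attached.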
|
@ex.config
def config():
    overwrite = None
    db_collection = None
    if db_collection is not None:
        ex.observers.append(
            seml.create_mongodb_observer(db_collection, overwrite=overwrite)
        )

|
class ExperimentWrapper:
    """Holds the training routine; sacred captures its arguments from the
    "training" sub-config."""

    def __init__(self, init_all=True):
        pass

|
    @ex.capture(prefix="training")
    def train(
        self,
        training_path,
        incl_zinc,
        save_path,
        batch_size,
        hidden_size,
        latent_size,
        depth,
        lr,
        gamma,
        max_epoch,
        num_workers,
        print_iter,
        save_iter,
        subsample_zinc_percent,
        pretrain_only=True,
        multip_share_strategy=None,
        model_path=None,
        beta=0.0,
        vocab_path=None,
    ):
|
        # Optionally switch torch's tensor-sharing strategy (e.g. to
        # "file_system"), which can avoid "too many open files" errors
        # with many DataLoader workers.
        if multip_share_strategy:
            torch.multiprocessing.set_sharing_strategy(multip_share_strategy)

        # Raise the soft limit on open file descriptors for the same reason.
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))

|
        # Name the combined training file after the config hash so parallel
        # runs write to distinct files.
        outpath = (
            Path.cwd()
            / "data"
            / f"train_{make_hash(ex.current_run.config)}.txt"
        )
        # ZINC training split downloaded by DGL's JT-VAE example.
        zinc_f = Path.home() / ".dgl" / "jtvae" / "train.txt"
        assert zinc_f.exists()

        # Validate the user-supplied training set before it is read below.
        if training_path:
            assert Path(training_path).exists(), training_path

|
        with open(outpath, "w") as outfile:
            n_total_smiles = 0
            if incl_zinc:
                with open(zinc_f) as infile:
                    for i, line in enumerate(infile):
                        # 220011 is the size of the ZINC split; keep only the
                        # requested fraction of it.
                        if i >= int(subsample_zinc_percent * 220011):
                            break
                        line = line.strip()
                        # Skip the CSV header and lithium-containing SMILES.
                        if line != "smiles" and "Li" not in line:
                            n_total_smiles += 1
                            outfile.write(line + "\n")

            with open(training_path) as infile:
                for line in infile:
                    line = line.strip()
                    if line != "smiles" and "Li" not in line:
                        n_total_smiles += 1
                        outfile.write(line + "\n")
        print(f"Total SMILES: {n_total_smiles}, stored at {outpath.resolve()}")

|
        # Both entry points consume an argparse-style namespace; build the
        # shared arguments once.
        common_args = {
            "train_path": str(outpath),
            "save_path": save_path,
            "batch_size": batch_size,
            "hidden_size": hidden_size,
            "latent_size": latent_size,
            "depth": depth,
            "lr": lr,
            "gamma": gamma,
            "max_epoch": max_epoch,
            "num_workers": num_workers,
            "print_iter": print_iter,
            "save_iter": save_iter,
            "use_cpu": False,
            "hash": make_hash(ex.current_run.config),
        }
        if pretrain_only:
            args = argparse.Namespace(**common_args)
            results = pretrain.main(args)
        else:
            args = argparse.Namespace(
                **common_args,
                vocab_path=vocab_path,
                model_path=model_path,
                beta=beta,
            )
            results = vaetrain.main(args)
        return results

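# Retrieve an experiment instance without recording a run, e.g. for
# interactive debugging in a notebook.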
|
@ex.command(unobserved=True)
def get_experiment(init_all=False):
    print("get_experiment")
    experiment = ExperimentWrapper(init_all=init_all)
    return experiment

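# Default entry point: sacred fills the train() arguments from the captured
# "training" sub-config.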
|
@ex.automain
def train(experiment=None):
    if experiment is None:
        experiment = ExperimentWrapper()
    return experiment.train()