File size: 1,391 Bytes
a48f0ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
from chemCPA.experiments_run import ExperimentWrapper
DEFAULT_CONFIG = "test_config_biolord.yaml"


def resolve_config_path(argv, default=DEFAULT_CONFIG):
    """Return the seml YAML config path to use for this run.

    The first command-line argument (``argv[1]``), when present, overrides
    *default*. Relative paths are interpreted relative to the current
    working directory, as ``Path.exists`` does.

    Args:
        argv: Argument vector, typically ``sys.argv``.
        default: Config path used when no argument is given.

    Returns:
        The chosen config path as a ``Path``.

    Raises:
        FileNotFoundError: If the chosen path does not exist. An explicit
            exception is used instead of ``assert`` so the check still runs
            under ``python -O``.
    """
    path = Path(argv[1]) if len(argv) > 1 else Path(default)
    if not path.exists():
        raise FileNotFoundError(f"config file not found: {path}")
    return path


if __name__ == "__main__":
    import sys

    exp = ExperimentWrapper(init_all=False)

    # this is how seml loads the config file internally
    config = resolve_config_path(sys.argv)
    # read_config returns (seml, slurm, experiment) configs; only the
    # experiment config is needed to drive a single local run.
    _, _, experiment_config = read_config(str(config))

    # generate_configs may expand the YAML into several configs; this script
    # only runs the first one, so warn when others are being ignored.
    configs = generate_configs(experiment_config)
    if not configs:
        raise ValueError(f"no configs generated from {config}")
    if len(configs) > 1:
        print("Careful, more than one config generated from the yaml file")
    args = configs[0]
    pprint(args)

    # Fixed seed. NOTE(review): assumes ExperimentWrapper reads exp.seed
    # when seeding; confirm.
    exp.seed = 1337

    # loads the dataset splits
    exp.init_dataset(**args["dataset"])

    model_args = args["model"]
    exp.init_drug_embedding(embedding=model_args["embedding"])
    exp.init_model(
        hparams=model_args["hparams"],
        additional_params=model_args["additional_params"],
        load_pretrained=model_args["load_pretrained"],
        append_ae_layer=model_args["append_ae_layer"],
        enable_cpa_mode=model_args["enable_cpa_mode"],
        pretrained_model_path=model_args["pretrained_model_path"],
        pretrained_model_hashes=model_args["pretrained_model_hashes"],
    )

    # setup the torch DataLoader
    exp.update_datasets()
    exp.train(**args["training"])
|