|
from pathlib import Path |
|
from pprint import pprint |
|
|
|
from seml.config import generate_configs, read_config |
|
|
|
from chemCPA.experiments_run import ExperimentWrapper |
|
|
|
if __name__ == "__main__": |
|
exp = ExperimentWrapper(init_all=False) |
|
|
|
|
|
config = "test_config_biolord.yaml" |
|
assert Path(config).exists(), "config file not found" |
|
seml_config, slurm_config, experiment_config = read_config(config) |
|
|
|
configs = generate_configs(experiment_config) |
|
if len(configs) > 1: |
|
print("Careful, more than one config generated from the yaml file") |
|
args = configs[0] |
|
pprint(args) |
|
|
|
exp.seed = 1337 |
|
|
|
exp.init_dataset(**args["dataset"]) |
|
|
|
exp.init_drug_embedding(embedding=args["model"]["embedding"]) |
|
exp.init_model( |
|
hparams=args["model"]["hparams"], |
|
additional_params=args["model"]["additional_params"], |
|
load_pretrained=args["model"]["load_pretrained"], |
|
append_ae_layer=args["model"]["append_ae_layer"], |
|
enable_cpa_mode=args["model"]["enable_cpa_mode"], |
|
pretrained_model_path=args["model"]["pretrained_model_path"], |
|
pretrained_model_hashes=args["model"]["pretrained_model_hashes"], |
|
) |
|
|
|
exp.update_datasets() |
|
|
|
exp.train(**args["training"]) |
|
|