{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Running hyperparameter optimization on Chemprop model using RayTune" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/hpopting.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", " try:\n", " import chemprop\n", " except ImportError:\n", " !git clone https://github.com/chemprop/chemprop.git\n", " %cd chemprop\n", " !pip install \".[hpopt]\"\n", " %cd examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "2024-10-22 09:03:28,414\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", "2024-10-22 09:03:28,801\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", "2024-10-22 09:03:29,333\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "from lightning import pytorch as pl\n", "import ray\n", "from ray import tune\n", "from ray.train import CheckpointConfig, RunConfig, ScalingConfig\n", "from ray.train.lightning import (RayDDPStrategy, RayLightningEnvironment,\n", " RayTrainReportCallback, prepare_trainer)\n", "from ray.train.torch import TorchTrainer\n", "from ray.tune.search.hyperopt import HyperOptSearch\n", "from ray.tune.search.optuna import OptunaSearch\n", "from ray.tune.schedulers import FIFOScheduler\n", "\n", "from chemprop import data, featurizers, models, nn" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n", "num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading\n", "smiles_column = 'smiles' # name of the column containing SMILES strings\n", "target_columns = ['lipo'] # list of names of the columns containing targets\n", "\n", "hpopt_save_dir = Path.cwd() / \"hpopt\" # directory to save hyperopt results\n", "hpopt_save_dir.mkdir(exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smileslipo
0Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc143.54
1COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...-1.18
2COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl3.69
3OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...3.37
4Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...3.10
.........
95CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...2.20
96CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...2.04
97CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...4.49
98COc1ccc(Cc2c(N)n[nH]c2N)cc10.20
99CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...2.00
\n", "

100 rows × 2 columns

\n", "
" ], "text/plain": [ " smiles lipo\n", "0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 3.54\n", "1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n", "2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 3.69\n", "3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 3.37\n", "4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 3.10\n", ".. ... ...\n", "95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... 2.20\n", "96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... 2.04\n", "97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... 4.49\n", "98 COc1ccc(Cc2c(N)n[nH]c2N)cc1 0.20\n", "99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... 2.00\n", "\n", "[100 rows x 2 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_input = pd.read_csv(input_path)\n", "df_input" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "smis = df_input.loc[:, smiles_column].values\n", "ys = df_input.loc[:, target_columns].values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make data points, splits, and datasets" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "mols = [d.mol for d in all_data] # RDkit Mol objects are use for structure based splits\n", "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n", "train_data, val_data, test_data = data.split_data_by_indices(\n", " all_data, train_indices, val_indices, test_indices\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n", "\n", "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n", "scaler = train_dset.normalize_targets()\n", "\n", "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n", "val_dset.normalize_targets(scaler)\n", "\n", "test_dset = data.MoleculeDataset(test_data[0], featurizer)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define helper function to train the model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def train_model(config, train_dset, val_dset, num_workers, scaler):\n", "\n", " # config is a dictionary containing hyperparameters used for the trial\n", " depth = int(config[\"depth\"])\n", " ffn_hidden_dim = int(config[\"ffn_hidden_dim\"])\n", " ffn_num_layers = int(config[\"ffn_num_layers\"])\n", " message_hidden_dim = int(config[\"message_hidden_dim\"])\n", "\n", " train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)\n", " val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n", "\n", " mp = nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth)\n", " agg = nn.MeanAggregation()\n", " output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n", " ffn = nn.RegressionFFN(output_transform=output_transform, input_dim=message_hidden_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers)\n", " batch_norm = True\n", " metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]\n", " model = models.MPNN(mp, agg, ffn, batch_norm, metric_list)\n", "\n", " trainer = pl.Trainer(\n", " accelerator=\"auto\",\n", " devices=1,\n", " max_epochs=20, # number of epochs to train for\n", " # below are needed for Ray and 
Lightning integration\n", " strategy=RayDDPStrategy(),\n", " callbacks=[RayTrainReportCallback()],\n", " plugins=[RayLightningEnvironment()],\n", " )\n", "\n", " trainer = prepare_trainer(trainer)\n", " trainer.fit(model, train_loader, val_loader)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameter search space" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "search_space = {\n", " \"depth\": tune.qrandint(lower=2, upper=6, q=1),\n", " \"ffn_hidden_dim\": tune.qrandint(lower=300, upper=2400, q=100),\n", " \"ffn_num_layers\": tune.qrandint(lower=1, upper=3, q=1),\n", " \"message_hidden_dim\": tune.qrandint(lower=300, upper=2400, q=100),\n", "}" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2024-10-22 09:05:01
Running for: 00:01:23.70
Memory: 10.9/15.3 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Logical resource usage: 2.0/12 CPUs, 0/0 GPUs\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc train_loop_config/de\n", "pth train_loop_config/ff\n", "n_hidden_dim train_loop_config/ff\n", "n_num_layers train_loop_config/me\n", "ssage_hidden_dim iter total time (s) train_loss train_loss_step val/rmse
TorchTrainer_f1a6e41aTERMINATED172.31.231.162:24873220002500 20 49.8815 0.0990423 0.168217 0.861368
TorchTrainer_d775c15dTERMINATED172.31.231.162:24953222002400 20 56.6533 0.069695 0.119898 0.90258
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\n", "\u001b[36m(TorchTrainer pid=24873)\u001b[0m Started distributed worker processes: \n", "\u001b[36m(TorchTrainer pid=24873)\u001b[0m - (ip=172.31.231.162, pid=24952) world_rank=0, local_rank=0, node_rank=0\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m GPU available: False, used: False\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m TPU available: False, using: 0 TPU cores\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Sanity Checking DataLoader 0: 0%| | 0/1 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000001)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1: 100%|██████████| 2/2 [00:01<00:00, 1.90it/s, v_num=0, train_loss_step=0.406, val_loss=0.904, train_loss_epoch=0.869]\n", "Epoch 2: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000002)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2: 100%|██████████| 2/2 [00:01<00:00, 1.29it/s, v_num=0, train_loss_step=1.290, val_loss=0.842, train_loss_epoch=1.210]\n", "Epoch 3: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3: 100%|██████████| 2/2 [00:01<00:00, 1.62it/s, v_num=0, train_loss_step=0.749, val_loss=0.912, train_loss_epoch=0.861]\n", "Epoch 4: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, 
path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000004)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4: 100%|██████████| 2/2 [00:01<00:00, 1.31it/s, v_num=0, train_loss_step=0.578, val_loss=0.912, train_loss_epoch=0.792]\n", "Epoch 5: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000005)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5: 100%|██████████| 2/2 [00:01<00:00, 1.59it/s, v_num=0, train_loss_step=0.751, val_loss=0.887, train_loss_epoch=0.618]\n", "Epoch 6: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6: 100%|██████████| 2/2 [00:01<00:00, 1.53it/s, v_num=0, train_loss_step=0.569, val_loss=0.876, train_loss_epoch=0.450]\n", "Epoch 7: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:15,207\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7: 50%|█████ | 1/2 [00:00<00:00, 2.28it/s, v_num=0, train_loss_step=0.339, val_loss=0.876, train_loss_epoch=0.450]\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n", "Epoch 1: 100%|██████████| 2/2 [00:00<00:00, 3.75it/s, v_num=0, train_loss_step=0.335, val_loss=0.854, train_loss_epoch=1.010]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "Validation: | | 0/? 
[00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1: 100%|██████████| 2/2 [00:00<00:00, 2.01it/s, v_num=0, train_loss_step=0.335, val_loss=0.893, train_loss_epoch=0.703]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "Epoch 2: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:17,399\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000009)\u001b[32m [repeated 6x across cluster]\u001b[0m\n", "2024-10-22 09:04:17,944\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:18,760\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:19,250\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:20,250\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11: 50%|█████ | 1/2 [00:00<00:00, 1.25it/s, v_num=0, train_loss_step=0.175, val_loss=0.897, train_loss_epoch=0.258]\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "Epoch 11: 100%|██████████| 2/2 [00:01<00:00, 1.79it/s, v_num=0, train_loss_step=0.312, val_loss=0.897, train_loss_epoch=0.258]\u001b[32m [repeated 7x across cluster]\u001b[0m\n", "Validation: | | 0/? [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \u001b[32m [repeated 11x across cluster]\u001b[0m\n", "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 7.84it/s]\u001b[A\u001b[32m [repeated 7x across cluster]\u001b[0m\n", "Epoch 11: 100%|██████████| 2/2 [00:01<00:00, 1.56it/s, v_num=0, train_loss_step=0.312, val_loss=0.869, train_loss_epoch=0.258]\u001b[32m [repeated 7x across cluster]\u001b[0m\n", "Epoch 11: 100%|██████████| 2/2 [00:01<00:00, 1.27it/s, v_num=0, train_loss_step=0.312, val_loss=0.869, train_loss_epoch=0.203]\u001b[32m [repeated 7x across cluster]\u001b[0m\n", "Epoch 12: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:22,323\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:22,766\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000013)\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "2024-10-22 09:04:24,404\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:25,524\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14: 50%|█████ | 1/2 [00:01<00:01, 0.88it/s, v_num=0, train_loss_step=0.131, val_loss=0.841, train_loss_epoch=0.141] \u001b[32m [repeated 6x across cluster]\u001b[0m\n", "Epoch 7: 100%|██████████| 2/2 [00:01<00:00, 1.13it/s, v_num=0, train_loss_step=0.368, val_loss=0.836, train_loss_epoch=0.399]\u001b[32m [repeated 5x across cluster]\u001b[0m\n", "Validation: | | 0/? [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:28,260\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000015)\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "2024-10-22 09:04:30,172\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 50%|█████ | 1/2 [00:01<00:01, 0.72it/s, v_num=0, train_loss_step=0.216, val_loss=0.889, train_loss_epoch=0.254]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "Epoch 9: 100%|██████████| 2/2 [00:01<00:00, 1.04it/s, v_num=0, train_loss_step=0.322, val_loss=0.889, train_loss_epoch=0.254]\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "Validation: | | 0/? [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 9x across cluster]\u001b[0m\n", "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 4.73it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "Epoch 9: 100%|██████████| 2/2 [00:02<00:00, 0.90it/s, v_num=0, train_loss_step=0.322, val_loss=0.910, train_loss_epoch=0.254]\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "Epoch 9: 100%|██████████| 2/2 [00:02<00:00, 0.70it/s, v_num=0, train_loss_step=0.322, val_loss=0.910, train_loss_epoch=0.237]\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "Epoch 16: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:33,534\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:34,844\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000011)\u001b[32m [repeated 5x across cluster]\u001b[0m\n", "2024-10-22 09:04:35,472\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18: 50%|█████ | 1/2 [00:01<00:01, 0.98it/s, v_num=0, train_loss_step=0.0962, val_loss=0.781, train_loss_epoch=0.116]\u001b[32m [repeated 5x across cluster]\u001b[0m\n", "Epoch 11: 100%|██████████| 2/2 [00:01<00:00, 1.91it/s, v_num=0, train_loss_step=0.263, val_loss=0.889, train_loss_epoch=0.219]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "Validation: | | 0/? [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:38,006\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019)\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "2024-10-22 09:04:40,708\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:04:41,380\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m `Trainer.fit` stopped: `max_epochs=20` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13: 50%|█████ | 1/2 [00:00<00:00, 1.17it/s, v_num=0, train_loss_step=0.118, val_loss=0.849, train_loss_epoch=0.122]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "Epoch 13: 100%|██████████| 2/2 [00:01<00:00, 1.62it/s, v_num=0, train_loss_step=0.0846, val_loss=0.849, train_loss_epoch=0.122]\u001b[32m [repeated 4x across cluster]\u001b[0m\n", "Validation: | | 0/? [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15: 50%|█████ | 1/2 [00:01<00:01, 0.64it/s, v_num=0, train_loss_step=0.0923, val_loss=0.839, train_loss_epoch=0.0974]\u001b[32m [repeated 2x across cluster]\u001b[0m\n", "Epoch 15: 100%|██████████| 2/2 [00:02<00:00, 0.94it/s, v_num=0, train_loss_step=0.0867, val_loss=0.839, train_loss_epoch=0.0974]\u001b[32m [repeated 2x across cluster]\u001b[0m\n", "Validation: | | 0/? 
[00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15: 100%|██████████| 2/2 [00:03<00:00, 0.54it/s, v_num=0, train_loss_step=0.0867, val_loss=0.837, train_loss_epoch=0.0912]\n", "Epoch 16: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16: 100%|██████████| 2/2 [00:04<00:00, 0.41it/s, v_num=0, train_loss_step=0.0703, val_loss=0.837, train_loss_epoch=0.0774]\n", "Epoch 17: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17: 100%|██████████| 2/2 [00:01<00:00, 1.01it/s, v_num=0, train_loss_step=0.156, val_loss=0.836, train_loss_epoch=0.0882]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18: 100%|██████████| 2/2 [00:01<00:00, 1.32it/s, v_num=0, train_loss_step=0.064, val_loss=0.830, train_loss_epoch=0.0675]\n", "Epoch 19: 0%| | 0/2 [00:00= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19: 100%|██████████| 2/2 [00:01<00:00, 1.55it/s, v_num=0, train_loss_step=0.120, val_loss=0.815, train_loss_epoch=0.0697]\n", "Epoch 19: 100%|██████████| 2/2 [00:01<00:00, 1.13it/s, v_num=0, train_loss_step=0.120, val_loss=0.815, train_loss_epoch=0.0697]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m `Trainer.fit` stopped: `max_epochs=20` reached.\n", "2024-10-22 09:05:01,809\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. 
A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", "2024-10-22 09:05:01,823\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37' in 0.0305s.\n", "2024-10-22 09:05:01,873\tINFO tune.py:1048 -- Total run time: 83.87 seconds (83.66 seconds for the tuning loop).\n" ] } ], "source": [ "ray.init()\n", "\n", "scheduler = FIFOScheduler()\n", "\n", "# Scaling config controls the resources used by Ray\n", "scaling_config = ScalingConfig(\n", " num_workers=1,\n", " use_gpu=False, # change to True if you want to use GPU\n", ")\n", "\n", "# Checkpoint config controls the checkpointing behavior of Ray\n", "checkpoint_config = CheckpointConfig(\n", " num_to_keep=1, # number of checkpoints to keep\n", " checkpoint_score_attribute=\"val_loss\", # Save the checkpoint based on this metric\n", " checkpoint_score_order=\"min\", # Save the checkpoint with the lowest metric value\n", ")\n", "\n", "run_config = RunConfig(\n", " checkpoint_config=checkpoint_config,\n", " storage_path=hpopt_save_dir / \"ray_results\", # directory to save the results\n", ")\n", "\n", "ray_trainer = TorchTrainer(\n", " lambda config: train_model(\n", " config, train_dset, val_dset, num_workers, scaler\n", " ),\n", " scaling_config=scaling_config,\n", " run_config=run_config,\n", ")\n", "\n", "search_alg = HyperOptSearch(\n", " n_initial_points=1, # number of random evaluations before tree parzen estimators\n", " random_state_seed=42,\n", ")\n", "\n", "# OptunaSearch is another search algorithm that can be used\n", "# search_alg = OptunaSearch() \n", "\n", "tune_config = tune.TuneConfig(\n", " metric=\"val_loss\",\n", " mode=\"min\",\n", " num_samples=2, # number of trials to run\n", " scheduler=scheduler,\n", " search_alg=search_alg,\n", " trial_dirname_creator=lambda trial: str(trial.trial_id), # shorten filepaths\n", " \n", ")\n", "\n", "tuner = tune.Tuner(\n", " ray_trainer,\n", " param_space={\n", " \"train_loop_config\": search_space,\n", " },\n", " tune_config=tune_config,\n", ")\n", "\n", "# Start the hyperparameter search\n", "results = tuner.fit()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter optimization results" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ResultGrid<[\n", " Result(\n", " metrics={'train_loss': 0.09904231131076813, 'train_loss_step': 0.16821686923503876, 'val/rmse': 0.8613682389259338, 'val/mae': 0.7006751298904419, 'val_loss': 0.7419552206993103, 'train_loss_epoch': 0.09904231131076813, 'epoch': 19, 'step': 40},\n", " path='/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a',\n", " filesystem='local',\n", " checkpoint=Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019)\n", " ),\n", " Result(\n", " metrics={'train_loss': 0.06969495117664337, 'train_loss_step': 0.11989812552928925, 'val/rmse': 0.902579665184021, 'val/mae': 0.7176367044448853, 'val_loss': 
0.8146500587463379, 'train_loss_epoch': 0.06969495117664337, 'epoch': 19, 'step': 40},\n", " path='/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d',\n", " filesystem='local',\n", " checkpoint=Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000019)\n", " )\n", "]>" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
train_losstrain_loss_stepval/rmseval/maeval_losstrain_loss_epochepochsteptimestampcheckpoint_dir_name...pidhostnamenode_iptime_since_restoreiterations_since_restoreconfig/train_loop_config/depthconfig/train_loop_config/ffn_hidden_dimconfig/train_loop_config/ffn_num_layersconfig/train_loop_config/message_hidden_dimlogdir
00.0990420.1682170.8613680.7006750.7419550.09904219401729602279checkpoint_000019...24873Knathan-Laptop172.31.231.16249.88151620220002500f1a6e41a
10.0696950.1198980.9025800.7176370.8146500.06969519401729602299checkpoint_000019...24953Knathan-Laptop172.31.231.16256.65333620222002400d775c15d
\n", "

2 rows × 27 columns

\n", "
" ], "text/plain": [ " train_loss train_loss_step val/rmse val/mae val_loss \\\n", "0 0.099042 0.168217 0.861368 0.700675 0.741955 \n", "1 0.069695 0.119898 0.902580 0.717637 0.814650 \n", "\n", " train_loss_epoch epoch step timestamp checkpoint_dir_name ... pid \\\n", "0 0.099042 19 40 1729602279 checkpoint_000019 ... 24873 \n", "1 0.069695 19 40 1729602299 checkpoint_000019 ... 24953 \n", "\n", " hostname node_ip time_since_restore iterations_since_restore \\\n", "0 Knathan-Laptop 172.31.231.162 49.881516 20 \n", "1 Knathan-Laptop 172.31.231.162 56.653336 20 \n", "\n", " config/train_loop_config/depth config/train_loop_config/ffn_hidden_dim \\\n", "0 2 2000 \n", "1 2 2200 \n", "\n", " config/train_loop_config/ffn_num_layers \\\n", "0 2 \n", "1 2 \n", "\n", " config/train_loop_config/message_hidden_dim logdir \n", "0 500 f1a6e41a \n", "1 400 d775c15d \n", "\n", "[2 rows x 27 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# results of all trials\n", "result_df = results.get_dataframe()\n", "result_df" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'depth': 2,\n", " 'ffn_hidden_dim': 2000,\n", " 'ffn_num_layers': 2,\n", " 'message_hidden_dim': 500}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# best configuration\n", "best_result = results.get_best_result()\n", "best_config = best_result.config\n", "best_config['train_loop_config']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best model checkpoint path: /home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019/checkpoint.ckpt\n" ] } ], "source": [ "# best model checkpoint path\n", "best_result = results.get_best_result()\n", "best_checkpoint_path = Path(best_result.checkpoint.path) / \"checkpoint.ckpt\"\n", "print(f\"Best model checkpoint path: {best_checkpoint_path}\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "ray.shutdown()" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }