# Config for fine-tuning chemCPA on Trapnell (sci-Plex) from LINCS pre-training with identical gene sets.
# This leads directly to parts 1 and 3 of `finetuning_num_genes` and also to `finetuning_OOD_prediction`.
seml:
  executable: chemCPA/seml_sweep_icb.py
  name: ft_sciplex_hparam
  output_dir: sweeps/logs
  conda_environment: chemical_CPA
  project_root_dir: ../..
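
# A note on the block above (a sketch of standard seml semantics, not specific
# to this repo): `executable` is the script seml invokes once per experiment,
# `project_root_dir` is resolved relative to this config file,
# `conda_environment` is activated before each run, and `name` is used for the
# slurm job names.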
slurm:
  max_simultaneous_jobs: 17
  experiments_per_job: 2
  sbatch_options_template: GPU
  sbatch_options:
    gres: gpu:1       # num GPUs
    mem: 32G          # memory
    cpus-per-task: 6  # num cores
    # speed is roughly 3 epochs / minute
    time: 1-00:01     # max time, D-HH:MM

###### BEGIN PARAMETER CONFIGURATION ######
fixed:
  profiling.run_profiler: False
  profiling.outdir: "./"

  training.checkpoint_freq: 15 # how often (in epochs) to run evaluation and potentially save a checkpoint
  training.num_epochs: 500 # maximum epochs for training; each update step trains either the autoencoder or the adversary, depending on adversary_steps
  training.max_minutes: 1200 # maximum wall-clock time in minutes
  training.full_eval_during_train: False
  training.run_eval_disentangle: True # whether to compute the disentanglement loss when running the full eval
  training.save_checkpoints: True # checkpoints tend to be ~250 MB for LINCS
  training.save_dir: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints

  dataset.dataset_type: trapnell
  dataset.data_params.dataset_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete_lincs_genes.h5ad # full path to the AnnData dataset
  dataset.data_params.perturbation_key: condition # stores the name of the drug
  dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose
  dataset.data_params.dose_key: dose # stores the drug dose as a float
  dataset.data_params.covariate_keys: cell_type # necessary field for cell types; fill with a dummy variable if no cell types are present
  dataset.data_params.smiles_key: SMILES
  dataset.data_params.degs_key: lincs_DEGs # `uns` key holding the DEGs for each perturbation
  dataset.data_params.use_drugs_idx: True # if False, one-hot encoding is used instead
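
  # The dotted keys above are a flat spelling of a nested config: seml expands
  # them into nested dictionaries before handing them to the experiment (a
  # sketch, assuming standard seml semantics), e.g. the `dataset.*` entries
  # become roughly:
  #
  #   dataset:
  #     dataset_type: trapnell
  #     data_params:
  #       perturbation_key: condition
  #       dose_key: dose
  #       ...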
  # model.load_pretrained is swept in the `grid` section below, hence commented out here:
  # model.load_pretrained: True
  model.pretrained_model_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints
  model.pretrained_model_hashes: # seml config hashes of the pretrained models for each embedding, used for loading model checkpoints. Hashes taken from `analyze_lincs_all_embeddings_hparam.ipynb`
    grover_base: ff420aea264fca7668ecb147f60762a1
    # MPNN: ff9629a1b216372be8b205556cabc6fb
    rdkit: 4f061dbfc7af05cf84f06a724b0c8563
    # weave: 1244d8b476696a7e1c01fd05d73d7450
    # jtvae: a7060ac4e2c6154e64a13acd414cbba2
    # seq2seq: e31119adc782888d5b75c57f8c803ee0
    # GCN: aedb25c686fb856e574a951f749b8dcf
    # vanilla: ba3569d1f5898a6bb964b7fafbed2641 # vanilla CPA; a new embedding will be trained
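
  # Presumably the training code resolves the checkpoint to warm-start from by
  # combining `model.pretrained_model_path` with the config hash of the active
  # embedding; the exact lookup lives in chemCPA's training code, not in seml.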
  model.additional_params.patience: 5 # patience for early stopping; effective epochs: patience * checkpoint_freq
  model.additional_params.decoder_activation: linear # last layer of the decoder
  model.additional_params.doser_type: amortized # non-linearity for the doser function
  model.additional_params.seed: 1337
  model.embedding.directory: null # null loads the path from paths.py

  # these were picked in the `lincs_rdkit_hparam` experiment
  model.hparams.dim: 32
  model.hparams.dropout: 0.262378
  model.hparams.autoencoder_width: 256
  model.hparams.autoencoder_depth: 4

random:
  samples: 15
  seed: 42

  model.hparams.batch_size:
    type: choice
    options:
      - 32
      - 64
      - 128
  model.hparams.autoencoder_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.autoencoder_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
  model.hparams.adversary_width:
    type: choice
    options:
      - 64
      - 128
      - 256
  model.hparams.adversary_depth:
    type: choice
    options:
      - 2
      - 3
      - 4
  model.hparams.adversary_lr:
    type: loguniform
    min: 5e-5
    max: 1e-2
  model.hparams.adversary_wd:
    type: loguniform
    min: 1e-8
    max: 1e-3
  model.hparams.adversary_steps: # every X steps, update the adversary INSTEAD OF the autoencoder
    type: choice
    options:
      - 2
      - 3
  model.hparams.reg_adversary:
    type: loguniform
    min: 5
    max: 100
  model.hparams.penalty_adversary:
    type: loguniform
    min: 0.5
    max: 5
  model.hparams.dosers_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.dosers_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
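
# How this section is sampled (a sketch of standard seml semantics): `samples: 15`
# draws 15 random hyperparameter configurations with RNG seed 42. `choice`
# picks uniformly from `options`; `loguniform` samples uniformly on a log
# scale, roughly exp(uniform(log(min), log(max))), so e.g. autoencoder_lr is
# about as likely to fall in [1e-4, 1e-3] as in [1e-3, 1e-2].
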
grid:
  model.load_pretrained:
    type: choice
    options:
      - True
      - False
  dataset.data_params.split_key: # necessary field for train, test, and ood splits
    type: choice
    options:
      - split_ho_pathway
      - split_ood_finetuning
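
# Grid parameters enumerate all combinations (here 2 x 2 = 4). Assuming
# standard seml semantics, each grid combination is paired with the random
# samples above, i.e. on the order of 4 * 15 = 60 experiments per sub-config.
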
rdkit:
  fixed:
    model.embedding.model: rdkit
    model.hparams.dosers_width: 64
    model.hparams.dosers_depth: 3
    # params for the corresponding hash - left in for reference
    # model.hparams.dosers_lr: 0.001121
    # model.hparams.dosers_wd: 3.752056e-7
    model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 128
    model.hparams.embedding_encoder_depth: 4

grover_base:
  fixed:
    model.embedding.model: grover_base
    model.hparams.dosers_width: 512
    model.hparams.dosers_depth: 2
    # params for the corresponding hash - left in for reference
    # model.hparams.dosers_lr: 0.000561
    # model.hparams.dosers_wd: 1.329292e-07
    model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 512
    model.hparams.embedding_encoder_depth: 3
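
# `rdkit` and `grover_base` are seml sub-configurations: each inherits the
# top-level `fixed`/`random`/`grid` blocks and adds its own fixed parameters,
# so the full sweep runs once per embedding. A usage sketch (the MongoDB
# collection name is chosen at `add` time; placeholders in angle brackets):
#
#   seml <collection_name> add <this_file>.yaml   # stage experiments
#   seml <collection_name> start                  # submit them via slurm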