# chemCPA/experiments/sciplex_hparam/config_sciplex_rdkit_hparam.yaml
# Config for fine-tuning chemCPA on Trapnell (sci-Plex), starting from LINCS pre-training with identical gene sets.
# This leads directly to parts 1 and 3 of `finetuning_num_genes` and also to `finetuning_OOD_prediction`.
seml:
  executable: chemCPA/seml_sweep_icb.py
  name: ft_sciplex_hparam
  output_dir: sweeps/logs
  conda_environment: chemical_CPA
  project_root_dir: ../..
slurm:
  max_simultaneous_jobs: 17
  experiments_per_job: 2
  sbatch_options_template: GPU
  sbatch_options:
    gres: gpu:1      # num GPUs
    mem: 32G         # memory
    cpus-per-task: 6 # num cores
    # speed is roughly 3 epochs / minute
    time: 1-00:01    # max time, D-HH:MM
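# Typical launch with the seml CLI (added example; the collection name below is
# illustrative and not part of this config):
#   seml ft_sciplex_hparam add config_sciplex_rdkit_hparam.yaml
#   seml ft_sciplex_hparam start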
###### BEGIN PARAMETER CONFIGURATION ######
fixed:
  profiling.run_profiler: False
  profiling.outdir: "./"

  training.checkpoint_freq: 15 # checkpoint frequency: run evaluation (and maybe save a checkpoint) every 15 epochs
  training.num_epochs: 500 # maximum epochs for training. One epoch updates either the autoencoder or the adversary, depending on adversary_steps.
  training.max_minutes: 1200 # maximum computation time
  training.full_eval_during_train: False
  training.run_eval_disentangle: True # whether to compute the disentanglement loss when running the full eval
  training.save_checkpoints: True # checkpoints tend to be ~250 MB for LINCS.
  training.save_dir: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints

  dataset.dataset_type: trapnell
  dataset.data_params.dataset_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete_lincs_genes.h5ad # full path to the anndata dataset
  dataset.data_params.perturbation_key: condition # stores the name of the drug
  dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose
  dataset.data_params.dose_key: dose # stores the drug dose as a float
  dataset.data_params.covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no cell types are present.
  dataset.data_params.smiles_key: SMILES
  dataset.data_params.degs_key: lincs_DEGs # `uns` key denoting the DEGs for each perturbation
  dataset.data_params.use_drugs_idx: True # if False, one-hot encoding is used instead

  # model.load_pretrained: True
  model.pretrained_model_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints
  model.pretrained_model_hashes: # seml config hashes of the pretrained models for each embedding, used for loading model checkpoints. Hashes taken from `analyze_lincs_all_embeddings_hparam.ipynb`.
    grover_base: ff420aea264fca7668ecb147f60762a1
    # MPNN: ff9629a1b216372be8b205556cabc6fb
    rdkit: 4f061dbfc7af05cf84f06a724b0c8563
    # weave: 1244d8b476696a7e1c01fd05d73d7450
    # jtvae: a7060ac4e2c6154e64a13acd414cbba2
    # seq2seq: e31119adc782888d5b75c57f8c803ee0
    # GCN: aedb25c686fb856e574a951f749b8dcf
    # vanilla: ba3569d1f5898a6bb964b7fafbed2641 # vanilla CPA; a new embedding will be trained

  model.additional_params.patience: 5 # patience for early stopping. Effective epochs: patience * checkpoint_freq.
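  # Illustrative reading (added comment, not from the original config): with
  # checkpoint_freq=15 and patience=5, evaluation runs every 15 epochs and training
  # stops after 5 evaluations without improvement, i.e. after roughly 5 * 15 = 75
  # stagnant epochs.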
  model.additional_params.decoder_activation: linear # last layer of the decoder
  model.additional_params.doser_type: amortized # non-linearity for doser function
  model.embedding.directory: null # null will load the path from paths.py
  model.additional_params.seed: 1337

  # these were picked in the `lincs_rdkit_hparam` experiment
  model.hparams.dim: 32
  model.hparams.dropout: 0.262378
  model.hparams.autoencoder_width: 256
  model.hparams.autoencoder_depth: 4
random:
  samples: 15
  seed: 42
  model.hparams.batch_size:
    type: choice
    options:
      - 32
      - 64
      - 128
  model.hparams.autoencoder_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.autoencoder_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
  model.hparams.adversary_width:
    type: choice
    options:
      - 64
      - 128
      - 256
  model.hparams.adversary_depth:
    type: choice
    options:
      - 2
      - 3
      - 4
  model.hparams.adversary_lr:
    type: loguniform
    min: 5e-5
    max: 1e-2
  model.hparams.adversary_wd:
    type: loguniform
    min: 1e-8
    max: 1e-3
  model.hparams.adversary_steps: # every X steps, update the adversary INSTEAD OF the autoencoder.
    type: choice
    options:
      - 2
      - 3
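  # Illustrative sketch of the alternation described above (an assumption about the
  # training loop, not stated in this file): with adversary_steps=2, roughly every
  # second minibatch step updates the adversary and the remaining steps update the
  # autoencoder/doser, e.g. something like `step % adversary_steps == 0`.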
  model.hparams.reg_adversary:
    type: loguniform
    min: 5
    max: 100
  model.hparams.penalty_adversary:
    type: loguniform
    min: 0.5
    max: 5
  model.hparams.dosers_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.dosers_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
grid:
  model.load_pretrained:
    type: choice
    options:
      - True
      - False
  dataset.data_params.split_key:
    type: choice
    options:
      - split_ho_pathway # necessary field for train, test, ood splits
      - split_ood_finetuning # necessary field for train, test, ood splits
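# Rough experiment count (added note; assumes seml combines every grid point with
# `samples` random draws): 2 (load_pretrained) x 2 (split_key) x 15 random samples
# = 60 runs per sub-config below, i.e. 120 runs in total for rdkit + grover_base.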
rdkit:
  fixed:
    model.embedding.model: rdkit
    model.hparams.dosers_width: 64
    model.hparams.dosers_depth: 3
    # Params for corresponding hash - left in for reference
    # model.hparams.dosers_lr: 0.001121
    # model.hparams.dosers_wd: 3.752056e-7
    model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 128
    model.hparams.embedding_encoder_depth: 4
grover_base:
  fixed:
    model.embedding.model: grover_base
    model.hparams.dosers_width: 512
    model.hparams.dosers_depth: 2
    # Params for corresponding hash - left in for reference
    # model.hparams.dosers_lr: 0.000561
    # model.hparams.dosers_wd: 1.329292e-07
    model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 512
    model.hparams.embedding_encoder_depth: 3
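# Note (added for clarity; relies on seml's documented sub-configuration behaviour):
# `rdkit` and `grover_base` above are sub-configs. Each inherits the shared `fixed`,
# `random`, and `grid` blocks and only overrides the embedding-specific parameters
# listed in its own `fixed` block, so the sweep is run once per embedding.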