chemCPA / experiments /baseline_comparison /baseline_experiment.yaml
github-actions[bot]
HF snapshot
a48f0ae
# Config for fine-tuning (on sci-Plex) a chemCPA model that was pretrained on L1000
# SEML experiment metadata. The keys below must be nested under `seml`;
# flush-left they would parse as unrelated top-level keys.
seml:
  executable: chemCPA/experiments_run.py
  name: 👾_baseline_comparison
  output_dir: project_folder/logs
  conda_environment: chemical_CPA
  # NOTE(review): presumably resolved relative to this config file's
  # directory -- confirm against the SEML version in use.
  project_root_dir: ../..
# Slurm scheduling settings. `sbatch_options` is a nested mapping of
# per-job resource requests.
slurm:
  max_simultaneous_jobs: 19
  experiments_per_job: 2
  sbatch_options_template: GPU
  sbatch_options:
    gres: gpu:1  # num GPUs
    mem: 32G  # memory
    cpus-per-task: 6  # num cores
    # speed is roughly 3 epochs / minute
    # Quoted so the digits-and-colon value is always read as a string.
    time: "1-00:01"  # max time, D-HH:MM
    nice: 10
###### BEGIN PARAMETER CONFIGURATION ######
# Parameters held constant for the runs defined in this file. Sub-configs
# further down set some of these keys again (e.g. `vanilla_CPA_model` sets its
# own `model.additional_params.doser_type`).
fixed:
  profiling.run_profiler: false
  profiling.outdir: "./"
  training.checkpoint_freq: 25  # checkpoint frequency to run evaluate, and maybe save checkpoint
  training.num_epochs: 201  # maximum epochs for training. One epoch updates either autoencoder, or adversary, depending on adversary_steps.
  training.max_minutes: 1200  # maximum computation time
  training.full_eval_during_train: false
  training.run_eval_disentangle: true  # whether to calc the disentanglement loss when running the full eval
  training.run_eval_r2: true
  training.run_eval_r2_sc: false
  training.run_eval_logfold: false
  training.save_checkpoints: true  # checkpoints tend to be ~250MB large for LINCS.
  training.save_dir: project_folder/sweeps/checkpoints

  dataset.dataset_type: trapnell
  dataset.data_params.perturbation_key: condition  # stores name of the drug
  dataset.data_params.pert_category: cov_drug_dose_name  # stores celltype_drugname_drugdose
  dataset.data_params.dose_key: dose  # stores drug dose as a float
  dataset.data_params.covariate_keys: cell_type  # necessary field for cell types. Fill it with a dummy variable if no celltypes present.
  dataset.data_params.smiles_key: SMILES
  dataset.data_params.use_drugs_idx: true  # If false, will use One-hot encoding instead
  dataset.data_params.dataset_path: project_folder/datasets/adata_baseline.h5ad  # full path to the anndata dataset
  dataset.data_params.degs_key: lincs_DEGs  # `uns` column name denoting the DEGs for each perturbation

  # model.load_pretrained: True
  model.pretrained_model_path: project_folder/checkpoints
  # seml config_hashes for the pretrained models for each embedding. Used for
  # loading model checkpoints. Hashes taken from
  # `analyze_lincs_all_embeddings_hparam.ipynb`.
  # Hashes are quoted so a hex digest can never be mistaken for a number.
  model.pretrained_model_hashes:
    grover_base: "ff420aea264fca7668ecb147f60762a1"
    # MPNN: ff9629a1b216372be8b205556cabc6fb
    rdkit: "4f061dbfc7af05cf84f06a724b0c8563"
    # weave: 1244d8b476696a7e1c01fd05d73d7450
    jtvae: "a7060ac4e2c6154e64a13acd414cbba2"
    # seq2seq: e31119adc782888d5b75c57f8c803ee0
    # GCN: 6b078a999327392c2d1b34c96154e317
    # vanilla: ba3569d1f5898a6bb964b7fafbed2641  # Vanilla CPA, new embedding will be trained.
  # NOTE(review): if the formula below holds, 50 * 25 = 1250 effective epochs
  # exceeds training.num_epochs (201), so early stopping would never trigger
  # -- confirm this is intended.
  model.additional_params.patience: 50  # patience for early stopping. Effective epochs: patience * checkpoint_freq.
  model.additional_params.decoder_activation: ReLU  # last layer of the decoder 'linear' or 'ReLU'
  model.additional_params.doser_type: amortized  # non-linearity for doser function
  model.embedding.directory: project_folder/embeddings  # null will load the path from paths.py
  model.additional_params.seed: 1337
  # these were picked in the `lincs_rdkit_hparam` experiment
  model.hparams.dim: 32
  model.hparams.dropout: 0.262378
  model.hparams.autoencoder_width: 256
  model.hparams.autoencoder_depth: 4
  model.hparams.reg_multi_task: 0
# Randomly sampled hyperparameters. `samples` and `seed` configure the search;
# each remaining key is one hyperparameter with its own `type` and either
# `options` (choice) or `min`/`max` (loguniform).
# Scientific-notation bounds are written with a decimal point (1.0e-4, not
# 1e-4) because plain YAML 1.1 loaders such as PyYAML read the dot-less form
# as a string rather than a float.
random:
  samples: 10
  seed: 42
  model.hparams.batch_size:
    type: choice
    options:
      - 32
      # - 64  # sciplex_hparam indicates 32 is best
      # - 128
  model.hparams.autoencoder_lr:
    type: loguniform
    min: 1.0e-4
    max: 1.0e-2
  model.hparams.autoencoder_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-5
  model.hparams.adversary_width:
    type: choice
    options:
      - 64  # results in b:4 indicate that 256 is best for grover and rdkit
      - 128
      - 256  # results in b:4 indicate that 256 is best for grover and rdkit
  model.hparams.adversary_depth:
    type: choice
    options:
      - 2
      - 3
      - 4
  model.hparams.adversary_lr:
    type: loguniform
    min: 5.0e-5
    max: 1.0e-2
  model.hparams.adversary_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-3
  model.hparams.adversary_steps:  # every X steps, update the adversary INSTEAD OF the autoencoder.
    type: choice
    options:
      - 2
      - 3
  model.hparams.reg_adversary:
    type: loguniform
    min: 1
    max: 40
  model.hparams.reg_adversary_cov:
    type: loguniform
    min: 3
    max: 50
  model.hparams.penalty_adversary:
    type: loguniform
    min: 0.05
    max: 2
  model.hparams.dosers_lr:
    type: loguniform
    min: 1.0e-4
    max: 1.0e-2
  model.hparams.dosers_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-5
# rdkit_all_genes:
# fixed:
# model.embedding.model: rdkit
# model.hparams.dosers_width: 64
# model.hparams.dosers_depth: 3
# # model.hparams.dosers_lr: 0.001121
# # model.hparams.dosers_wd: 3.752056e-7
# model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
# model.hparams.embedding_encoder_width: 128
# model.hparams.embedding_encoder_depth: 4
# # model.hparams.adversary_width: 256
# model.append_ae_layer: True
# dataset.data_params.dataset_path: project_folder/datasets/adata_baseline.h5ad # full path to the anndata dataset
# dataset.data_params.degs_key: all_DEGs # `uns` column name denoting the DEGs for each perturbation
# Grid over the dataset split column: one entry per listed split key.
grid:
  dataset.data_params.split_key:
    type: choice
    options:
      - split_baseline_A549
      - split_baseline_K562
      - split_baseline_MCF7
# Sub-config for the vanilla CPA baseline: no pretrained weights are loaded,
# and `doser_type` is set to `sigm` here instead of the `amortized` value in
# the top-level `fixed` block.
vanilla_CPA_model:
  fixed:
    model.load_pretrained: false
    model.embedding.model: vanilla
    model.hparams.step_size_lr: 50  # this applies to all optimizers (AE, ADV, DRUG)
    model.additional_params.doser_type: sigm
    model.append_ae_layer: false
    model.enable_cpa_mode: true
# Sub-config for the RDKit-embedding model. Its own grid lists both values of
# `model.load_pretrained` (true and false).
rdkit_lincs_genes:
  grid:
    model.load_pretrained:
      type: choice
      options:
        - true
        - false
  fixed:
    model.embedding.model: rdkit
    model.hparams.dosers_width: 64
    model.hparams.dosers_depth: 3
    # model.hparams.dosers_lr: 0.001121
    # model.hparams.dosers_wd: 3.752056e-7
    model.hparams.step_size_lr: 50  # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 128
    model.hparams.embedding_encoder_depth: 4
    # model.hparams.adversary_width: 256
    model.append_ae_layer: false
    model.enable_cpa_mode: false