|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import scanpy as sc |
|
import scipy |
|
import seaborn as sns |
|
import umap.plot |
|
from utils import ( |
|
compute_drug_embeddings, |
|
compute_pred, |
|
compute_pred_ctrl, |
|
load_config, |
|
load_dataset, |
|
load_model, |
|
load_smiles, |
|
) |
|
|
|
from chemCPA.data import load_dataset_splits |
|
from chemCPA.paths import FIGURE_DIR, ROOT |
|
|
|
# Global plotting style for every figure produced by this script.
matplotlib.style.use("fivethirtyeight")

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old aliases, so pick whichever name this installation knows.
_talk_style = (
    "seaborn-v0_8-talk"
    if "seaborn-v0_8-talk" in matplotlib.style.available
    else "seaborn-talk"
)
matplotlib.style.use(_talk_style)

matplotlib.rcParams["font.family"] = "monospace"
matplotlib.rcParams["figure.dpi"] = 300
matplotlib.pyplot.rcParams["savefig.facecolor"] = "white"

# Seaborn settings are applied last, so they take precedence where they
# overlap with the Matplotlib styles selected above.
sns.set_style("whitegrid")
sns.set_context("poster")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Collection name handed to ``load_config`` together with one of the hashes below.
seml_collection = "multi_task"

# One (pretrained, scratch) hash pair per drug-embedding model.
model_hash_pretrained_rdkit, model_hash_scratch_rdkit = (
    "dde01c1c58f398d524453c4b564a440f",
    "475e26950b2c531bea88931a4b2c27b7",
)

model_hash_pretrained_grover, model_hash_scratch_grover = (
    "0f4a3b11e1fbe3da58125f39ff6fb035",
    "b372147c80cf9ad4bd10d16bc56b7534",
)

model_hash_pretrained_jtvae, model_hash_scratch_jtvae = (
    "e4eac660c5830245f681ec3cc5099f21",
    "6b465400467f69da861e3ef0b4709e19",
)
|
|
|
|
|
|
|
|
|
|
|
# The dataset is loaded once, using the configuration stored for the
# pretrained RDKit run.
config = load_config(seml_collection, model_hash_pretrained_rdkit)

# The stored dataset path is joined onto the repository ROOT before loading.
_stored_dataset_path = config["dataset"]["data_params"]["dataset_path"]
config["dataset"]["data_params"]["dataset_path"] = ROOT / _stored_dataset_path

dataset, key_dict = load_dataset(config)
config["dataset"]["n_vars"] = dataset.n_vars

# NOTE(review): the meaning of the positional ``True`` is defined by
# ``utils.load_smiles`` and is not visible here; the call returns three values.
_smiles_outputs = load_smiles(config, dataset, key_dict, True)
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = _smiles_outputs
|
|
|
|
|
|
|
|
|
|
|
# Unique ``condition`` values of all observations whose split column equals "ood".
_split_key = config["dataset"]["data_params"]["split_key"]
_ood_mask = dataset.obs[_split_key].isin(["ood"])
_ood_conditions = dataset.obs.condition[_ood_mask]
ood_drugs = _ood_conditions.unique().to_list()
|
|
|
|
|
|
|
|
|
|
|
# Bare expression kept from the notebook: it only displays the parameters in
# an interactive session and has no effect when run as a script.
config["dataset"]["data_params"]

# Build the dataset splits from the stored data parameters; the result is
# indexed by split name below (e.g. "ood", "test_control").
data_params = config["dataset"]["data_params"]
datasets = load_dataset_splits(return_dataset=False, **data_params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Evaluation grid shared by the baseline and by every model below.
dosages = [1e1, 1e2, 1e3, 1e4]
cell_lines = ["A549", "K562", "MCF7"]
use_DEGs = True

# Arguments common to both baseline calls; only ``use_DEGs`` differs.
_baseline_kwargs = {
    "dataset": datasets["ood"],
    "dataset_ctrl": datasets["test_control"],
    "dosages": dosages,
    "cell_lines": cell_lines,
    "verbose": False,
}

# Baseline scores computed from the control split, once with ``use_DEGs=True``
# and once with ``use_DEGs=False``. The second return value is discarded.
drug_r2_baseline_degs, _ = compute_pred_ctrl(use_DEGs=True, **_baseline_kwargs)
drug_r2_baseline_all, _ = compute_pred_ctrl(use_DEGs=False, **_baseline_kwargs)
|
|
|
|
|
|
|
|
|
|
|
# Bare expression kept from the notebook: it displays the OOD drug list in an
# interactive session and has no effect when run as a script.
ood_drugs
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_pretrained_rdkit`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_pretrained_rdkit)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_pretrained_rdkit, embedding_pretrained_rdkit = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_pretrained_degs_rdkit, _ = compute_pred(
    model_pretrained_rdkit, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_pretrained_all_rdkit, _ = compute_pred(
    model_pretrained_rdkit, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_scratch_rdkit`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_scratch_rdkit)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_scratch_rdkit, embedding_scratch_rdkit = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_scratch_degs_rdkit, _ = compute_pred(
    model_scratch_rdkit, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_scratch_all_rdkit, _ = compute_pred(
    model_scratch_rdkit, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_pretrained_grover`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_pretrained_grover)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_pretrained_grover, embedding_pretrained_grover = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_pretrained_degs_grover, _ = compute_pred(
    model_pretrained_grover, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_pretrained_all_grover, _ = compute_pred(
    model_pretrained_grover, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_scratch_grover`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_scratch_grover)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_scratch_grover, embedding_scratch_grover = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_scratch_degs_grover, _ = compute_pred(
    model_scratch_grover, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_scratch_all_grover, _ = compute_pred(
    model_scratch_grover, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_pretrained_jtvae`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_pretrained_jtvae)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_pretrained_jtvae, embedding_pretrained_jtvae = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_pretrained_degs_jtvae, _ = compute_pred(
    model_pretrained_jtvae, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_pretrained_all_jtvae, _ = compute_pred(
    model_pretrained_jtvae, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
# Load the model stored under ``model_hash_scratch_jtvae`` and score it on
# the "ood" split.
config = load_config(seml_collection, model_hash_scratch_jtvae)
config["dataset"]["n_vars"] = dataset.n_vars

# The stored embedding directory is joined onto the repository ROOT.
_embedding_cfg = config["model"]["embedding"]
_embedding_cfg["directory"] = ROOT / _embedding_cfg["directory"]

_loaded = load_model(config, canon_smiles_unique_sorted)
model_scratch_jtvae, embedding_scratch_jtvae = _loaded

# Arguments common to both scoring calls; only ``use_DEGs`` differs.
_ood_eval_kwargs = dict(
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    verbose=False,
)
drug_r2_scratch_degs_jtvae, _ = compute_pred(
    model_scratch_jtvae, datasets["ood"], use_DEGs=True, **_ood_eval_kwargs
)
drug_r2_scratch_all_jtvae, _ = compute_pred(
    model_scratch_jtvae, datasets["ood"], use_DEGs=False, **_ood_eval_kwargs
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tagged_r2_frame(drug_r2, run_type, model_name):
    """Build a one-column ``r2_de`` frame from ``drug_r2`` and tag its origin."""
    frame = pd.DataFrame.from_dict(drug_r2, orient="index", columns=["r2_de"])
    frame["type"] = run_type
    frame["model"] = model_name
    return frame


def _split_condition_key(key):
    """Split a ``"<cell_line>_<drug>_<dose>"`` key into its three parts.

    The first token is taken as the cell line and the last one as the dose;
    everything in between is the drug, so drug names that themselves contain
    underscores stay intact.
    """
    tokens = key.split("_")
    if len(tokens) < 3:
        raise ValueError(
            f"Expected a key of the form '<cell_line>_<drug>_<dose>', got {key!r}"
        )
    return tokens[0], "_".join(tokens[1:-1]), tokens[-1]


def create_df(
    drug_r2_baseline,
    drug_r2_pretrained_rdkit,
    drug_r2_scratch_rdkit,
    drug_r2_pretrained_grover,
    drug_r2_scratch_grover,
    drug_r2_pretrained_jtvae,
    drug_r2_scratch_jtvae,
):
    """Stack the per-condition scores of all runs into one long-format frame.

    Each argument is a mapping from a ``"<cell_line>_<drug>_<dose>"`` key to a
    score. The result has one row per (run, key) with the columns ``index``
    (the original key), ``r2_de`` (score, negative values replaced by 0),
    ``type`` ("baseline", "pretrained" or "non-pretrained"), ``model``
    ("baseline", "rdkit", "grover" or "jtvae"), ``cell_line``, ``drug`` and
    ``dose`` (float).

    Raises:
        ValueError: if a key has fewer than three ``_``-separated tokens or
            its last token cannot be converted to ``float``.
    """
    tagged_inputs = (
        (drug_r2_baseline, "baseline", "baseline"),
        (drug_r2_pretrained_rdkit, "pretrained", "rdkit"),
        (drug_r2_scratch_rdkit, "non-pretrained", "rdkit"),
        (drug_r2_pretrained_grover, "pretrained", "grover"),
        (drug_r2_scratch_grover, "non-pretrained", "grover"),
        (drug_r2_pretrained_jtvae, "pretrained", "jtvae"),
        (drug_r2_scratch_jtvae, "non-pretrained", "jtvae"),
    )
    df = pd.concat([_tagged_r2_frame(*entry) for entry in tagged_inputs])

    # Negative scores are floored at zero.
    df["r2_de"] = df["r2_de"].apply(lambda x: max(x, 0))

    # The index repeats across runs, so the parsed parts are assigned by
    # position (plain lists) rather than by label.
    key_parts = [_split_condition_key(key) for key in df.index.values]
    df["cell_line"] = [parts[0] for parts in key_parts]
    df["drug"] = [parts[1] for parts in key_parts]
    df["dose"] = [parts[2] for parts in key_parts]
    df["dose"] = df["dose"].astype(float)

    # Move the condition key out of the index into an ``index`` column.
    df = df.reset_index()
    return df
|
|
|
|
|
|
|
# Long-format score table built from the ``*_degs`` results.
df_degs = create_df(
    drug_r2_baseline=drug_r2_baseline_degs,
    drug_r2_pretrained_rdkit=drug_r2_pretrained_degs_rdkit,
    drug_r2_scratch_rdkit=drug_r2_scratch_degs_rdkit,
    drug_r2_pretrained_grover=drug_r2_pretrained_degs_grover,
    drug_r2_scratch_grover=drug_r2_scratch_degs_grover,
    drug_r2_pretrained_jtvae=drug_r2_pretrained_degs_jtvae,
    drug_r2_scratch_jtvae=drug_r2_scratch_degs_jtvae,
)

# Same table built from the ``*_all`` results.
df_all = create_df(
    drug_r2_baseline=drug_r2_baseline_all,
    drug_r2_pretrained_rdkit=drug_r2_pretrained_all_rdkit,
    drug_r2_scratch_rdkit=drug_r2_scratch_all_rdkit,
    drug_r2_pretrained_grover=drug_r2_pretrained_all_grover,
    drug_r2_scratch_grover=drug_r2_scratch_all_grover,
    drug_r2_pretrained_jtvae=drug_r2_pretrained_all_jtvae,
    drug_r2_scratch_jtvae=drug_r2_scratch_all_jtvae,
)
|
|
|
|
|
|
|
|
|
|
|
def _r2_statistic_at_dose_one(frame, statistic):
    """Print and collect one ``r2_de`` statistic per (model, type) at dose 1.0.

    ``statistic`` is the name of the Series method to call ("mean" or
    "median"). Returns the group keys and the matching values, both in
    groupby order.
    """
    group_keys = []
    values = []
    for group_key, group in frame.groupby(["model", "type", "dose"]):
        # Only the groups at dose 1.0 enter the summary table.
        if group_key[2] != 1.0:
            continue
        value = getattr(group.r2_de, statistic)()
        print(f"Model: {group_key}, R2 {statistic}: {value}")
        group_keys.append(group_key)
        values.append(value)
    return group_keys, values


_, r2_degs_mean = _r2_statistic_at_dose_one(df_degs, "mean")
_, r2_all_mean = _r2_statistic_at_dose_one(df_all, "mean")
_, r2_degs_median = _r2_statistic_at_dose_one(df_degs, "median")
_all_group_keys, r2_all_median = _r2_statistic_at_dose_one(df_all, "median")

# Row labels are taken from the last pass (all genes); the four value lists
# are expected to follow the same groupby order.
model = [group_key[0] for group_key in _all_group_keys]
model_type = [group_key[1] for group_key in _all_group_keys]

df_dict = {
    "Model": model,
    "Type": model_type,
    "Mean $r^2$ all": r2_all_mean,
    "Mean $r^2$ DEGs": r2_degs_mean,
    "Median $r^2$ all": r2_all_median,
    "Median $r^2$ DEGs": r2_degs_median,
}

df = pd.DataFrame.from_dict(df_dict).set_index("Model")

# Emit the summary as a LaTeX table with two decimals.
print(df.to_latex(float_format="%.2f"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Paired t-tests of each pretrained model against the baseline and against
# its non-pretrained counterpart, at a single dose.
dose = 1.0

models = []
gene_set = []
p_values = []
vs_models = []


def _paired_r2(frame, model, vs_model, dose):
    """Return pretrained and reference ``r2_de`` values paired by condition key.

    ``frame`` is a table as produced by ``create_df``. The pretrained sample
    holds the rows of ``model`` with type "pretrained"; the reference sample
    holds the rows whose type equals ``vs_model`` ("baseline" rows belong to
    the "baseline" model, "non-pretrained" rows to ``model``). Both are
    restricted to ``dose`` and aligned on the ``index`` column, so the paired
    test compares identical conditions instead of relying on row order.
    """
    selected = frame[(frame.dose == dose) & frame.model.isin(["baseline", model])]
    r2_by_key = selected.set_index("index").r2_de
    pretrained = r2_by_key[(selected.type == "pretrained").values]
    reference = r2_by_key[(selected.type == vs_model).values]
    # Keep only the conditions present in both samples, in a shared order.
    return pretrained.align(reference, join="inner")


for model in ["rdkit", "grover", "jtvae"]:
    for vs_model in ["baseline", "non-pretrained"]:
        # Result rows keep the order: all genes first, then DEGs.
        for _gene_set_label, _frame in (("all genes", df_all), ("DEGs", df_degs)):
            _pretrained_r2, _reference_r2 = _paired_r2(_frame, model, vs_model, dose)
            stat, pvalue = scipy.stats.ttest_rel(_pretrained_r2, _reference_r2)

            models.append(model)
            gene_set.append(_gene_set_label)
            p_values.append(pvalue)
            vs_models.append(vs_model)


df_dict = {
    "Model $G$": models,
    "Against": vs_models,
    "Gene set": gene_set,
    "p-value": p_values,
}

df = pd.DataFrame.from_dict(df_dict)
df = df.set_index("Model $G$")

# Emit the p-values as a LaTeX table with four decimals.
print(df.to_latex(float_format="%.4f"))
|
|
|
|
|
|
|
|
|
|
|
|