# chemCPA/experiments/dom_experiments/analyze_sciplex_runs.py
# (snapshot metadata: github-actions[bot], HF snapshot, commit a48f0ae)
# ---
# jupyter:
# jupytext:
# notebook_metadata_filter: -kernelspec
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.1
# ---
# %% [markdown] pycharm={"name": "#%% md\n"}
# # Analyzing the results for `finetuning_num_genes`
#
# This is part 1, the results of sweeping all hyperparameters for rdkit
# %% pycharm={"name": "#%%\n"}
from pathlib import Path
import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
import seml
from matplotlib import pyplot as plt
# Global plotting configuration for the whole notebook.
matplotlib.style.use("fivethirtyeight")
# The "seaborn-talk" style was renamed to "seaborn-v0_8-talk" in matplotlib 3.6
# and the old alias was removed in 3.8; fall back so the notebook runs on both.
try:
    matplotlib.style.use("seaborn-talk")
except OSError:
    matplotlib.style.use("seaborn-v0_8-talk")
matplotlib.rcParams["font.family"] = "monospace"
plt.rcParams["savefig.facecolor"] = "white"  # no transparent background on export
sns.set_context("poster")
# seml result frames are very wide; show up to 100 columns when displaying.
pd.set_option("display.max_columns", 100)
# %% pycharm={"name": "#%%\n"}
# Pull all COMPLETED runs of batch 1 from the `check_seml` MongoDB collection
# into a flat DataFrame (one row per run, dot-separated column names).
# NOTE(review): the filter mixes a bare key ("batch_id") with "config."-prefixed
# keys — presumably both address the stored experiment document; confirm
# against the seml collection schema.
results = seml.get_results(
    "check_seml",
    to_data_frame=True,
    # keep the full config, all metrics, seml bookkeeping, and the config hash
    fields=["config", "result", "seml", "config_hash"],
    states=["COMPLETED"],
    filter_dict={
        "batch_id": 1,
        # only the OOD multi-task split with the extra AE layer appended
        "config.dataset.data_params.split_key": "split_ood_multi_task",
        "config.model.append_ae_layer": True,
    },
)
# %% pycharm={"name": "#%%\n"}
# Look at number of experiments per model
results["config.model.embedding.model"].value_counts()
# %% pycharm={"name": "#%%\n"}
# How many runs per embedding are missing the perturbation disentanglement
# score (a proxy for failed / NaN runs)?
pd.crosstab(
    results["config.model.embedding.model"],
    results["result.perturbation disentanglement"].isnull(),
)
# %% pycharm={"name": "#%%\n"}
# Which columns relate to autoencoder settings?
[c for c in results.columns if "ae" in c]
# %% pycharm={"name": "#%%\n"}
# Runs per split vs. the pretraining flag.
pd.crosstab(
    results["config.dataset.data_params.split_key"],
    results["config.model.load_pretrained"],
)
# %% pycharm={"name": "#%%\n"}
# Runs per split that are missing the reconstruction loss.
pd.crosstab(
    results["config.dataset.data_params.split_key"],
    results["result.loss_reconstruction"].isnull(),
)
# %% pycharm={"name": "#%%\n"}
# columns that contain at least one NaN
results.isnull().any()[results.isnull().any()]
# %% pycharm={"name": "#%%\n"}
# rows without nans: IDs of runs that recorded a training history
clean_id = results.loc[~results["result.training"].isnull(), "_id"]
# clean_id
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Preprocessing the results dataframe
# %% pycharm={"name": "#%%\n"}
# Keep only runs that produced a training history (did not NaN out or crash).
results_clean = results[results["_id"].isin(clean_id)].copy()
# Format as an actual percentage so the value matches the message
# (previously printed a raw fraction, e.g. 0.05 instead of 5.00%).
print(f"Percentage of invalid (nan) runs: {1 - len(clean_id) / len(results):.2%}")

# Optionally remove runs with r2 < 0.6 on the training set:
# results_clean = results_clean[results_clean['result.training'].apply(lambda x: x[0][0])>0.6]
# %% pycharm={"name": "#%%\n"}
# Runs per embedding model after removing failed runs.
results_clean["config.model.embedding.model"].value_counts()
# %% pycharm={"name": "#%%\n"}
# Pretraining / AE-layer settings next to the OOD metric history per run.
results_clean[
    ["config.model.load_pretrained", "config.model.append_ae_layer", "result.ood"]
]
# %% pycharm={"name": "#%%\n"}
# Summarise the per-epoch metric histories into scalar columns.


def get_mean(history):
    """All-genes mean R^2 of the final evaluation epoch.

    `history` is a nested sequence with one row per evaluation epoch;
    column 0 holds the mean over all genes. (Was a lambda; PEP 8 E731.)
    """
    return np.array(history)[-1, 0]


def get_mean_de(history):
    """DE-genes mean R^2 of the final evaluation epoch (column 1)."""
    return np.array(history)[-1, 1]


results_clean["result.training_mean"] = results_clean["result.training"].apply(get_mean)
results_clean["result.training_mean_de"] = results_clean["result.training"].apply(
    get_mean_de
)
# NOTE: in this project's naming, "result.test" is the validation split and
# "result.ood" is the held-out test split.
results_clean["result.val_mean"] = results_clean["result.test"].apply(get_mean)
results_clean["result.val_mean_de"] = results_clean["result.test"].apply(get_mean_de)
results_clean["result.test_mean"] = results_clean["result.ood"].apply(get_mean)
results_clean["result.test_mean_de"] = results_clean["result.ood"].apply(get_mean_de)
# Disentanglement scores are stored wrapped in (nested) lists; unwrap the scalar.
results_clean["result.perturbation disentanglement"] = results_clean[
    "result.perturbation disentanglement"
].apply(lambda x: x[0])
results_clean["result.covariate disentanglement"] = results_clean[
    "result.covariate disentanglement"
].apply(lambda x: x[0][0])
# Reconstruction loss: keep the value from the last epoch.
results_clean["result.final_reconstruction"] = results_clean[
    "result.loss_reconstruction"
].apply(lambda x: x[-1])
results_clean.head(3)
# %% pycharm={"name": "#%%\n"}
# results_clean["result.training_sc_mean"] = results_clean["result.training_sc"].apply(get_mean)
# results_clean["result.training_sc_mean_de"] = results_clean["result.training_sc"].apply(get_mean_de)
# results_clean["result.val_sc_mean"] = results_clean["result.test_sc"].apply(get_mean)
# results_clean["result.val_sc_mean_de"] = results_clean["result.test_sc"].apply(get_mean_de)
# results_clean["result.test_sc_mean"] = results_clean["result.ood_sc"].apply(get_mean)
# results_clean["result.test_sc_mean_de"] = results_clean["result.ood_sc"].apply(get_mean_de)
# results_clean = results_clean[results_clean["result.val_sc_mean"] > 0.01]
# results_clean = results_clean[results_clean["result.val_mean_de"] > 0.4]
# # results_clean = results_clean[results_clean["config.model.append_ae_layer"] == True]
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Look at early stopping
# %% pycharm={"name": "#%%\n"}
# Distribution of the number of training epochs before early stopping, split by
# whether the model started from a pretrained checkpoint. `result.epoch` is a
# per-run history; `max` gives the last (final) epoch reached.
# (`== True` / `== False` are element-wise pandas masks, kept on purpose.)
fig, ax = plt.subplots(2, 1)
sns.histplot(
    data=results_clean[results_clean["config.model.load_pretrained"] == True][
        "result.epoch"
    ].apply(max),
    ax=ax[0],
)
ax[0].set_title("Total epochs before final stopping (min 125), pretrained")
# histplot draws onto the axis passed via `ax=`; no need to rebind ax[1]
# (the previous `ax[1] = sns.histplot(...)` was redundant and inconsistent
# with the first call above).
sns.histplot(
    data=results_clean[results_clean["config.model.load_pretrained"] == False][
        "result.epoch"
    ].apply(max),
    ax=ax[1],
)
ax[1].set_title("Total epochs before final stopping (min 125), non pretrained")
plt.tight_layout()
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Look at $r^2$ reconstruction
# %% pycharm={"name": "#%%\n"}
# NOTE(review): this list comprehension is not the last expression of the cell,
# so its result is discarded and never displayed.
[c for c in results_clean.columns if "pretrain" in c]
# Embedding / pretraining / AE-layer configuration per run.
results_clean[
    [
        "config.model.embedding.model",
        "config.model.load_pretrained",
        "config.model.append_ae_layer",
    ]
]
# %% [markdown] pycharm={"name": "#%% md\n"}
# ### DE genes
# %% pycharm={"name": "#%%\n"}
# DE-gene R^2 (train / val / test) per embedding model, split by pretraining.
de_metrics = ("result.training_mean_de", "result.val_mean_de", "result.test_mean_de")
n_rows, n_cols = 1, 3
fig, ax = plt.subplots(n_rows, n_cols, figsize=(10 * n_cols, 6 * n_rows))
for idx, metric in enumerate(de_metrics):
    panel = ax[idx]
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=metric,
        hue="config.model.load_pretrained",
        inner="points",
        ax=panel,
        scale="width",
    )
    # panel.set_ylim([0.3,1.01])
    panel.set_xticklabels(panel.get_xticklabels(), rotation=75, ha="right")
    panel.set_xlabel("")
    panel.set_ylabel(metric.split(".")[-1])
    panel.legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)
# Keep a single legend, placed outside the rightmost panel.
ax[0].get_legend().remove()
ax[1].get_legend().remove()
ax[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# %% pycharm={"name": "#%%\n"}
# # DE genes
# rows, cols = 1, 3
# fig, ax = plt.subplots(rows, cols, figsize=(10*cols,6*rows))
# for i, y in enumerate(("result.training_sc_mean_de", "result.val_sc_mean_de", "result.test_sc_mean_de")):
# sns.violinplot(
# data=results_clean,
# x="config.model.embedding.model",
# y=y,
# hue='config.model.load_pretrained',
# inner='points',
# ax=ax[i],
# scale='width',
# )
# ax[i].set_ylim([0.0,0.5])
# ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha='right')
# ax[i].set_xlabel('')
# ax[i].set_ylabel(y.split('.')[-1])
# ax[i].legend(title='Pretrained',loc='lower right', fontsize=18, title_fontsize=24)
# ax[0].get_legend().remove()
# ax[1].get_legend().remove()
# ax[2].legend(title='Pretrained', fontsize=18, title_fontsize=24, loc="center left", bbox_to_anchor=(1, 0.5))
# plt.tight_layout()
# %% [markdown] pycharm={"name": "#%% md\n"}
# ### All genes
# %% pycharm={"name": "#%%\n"}
# All-genes R^2 (train / val / test) per embedding model, split by pretraining.
metrics_all = ["result.training_mean", "result.val_mean", "result.test_mean"]
rows, cols = 1, 3
fig, axes = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))
for panel, metric in zip(axes, metrics_all):
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=metric,
        hue="config.model.load_pretrained",
        inner="points",
        ax=panel,
        scale="width",
    )
    # panel.set_ylim([0.3,1.01])
    panel.set_xticklabels(panel.get_xticklabels(), rotation=75, ha="right")
    panel.set_xlabel("")
    panel.set_ylabel(metric.split(".")[-1])
    panel.legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)
# Keep a single legend, placed outside the rightmost panel.
axes[0].get_legend().remove()
axes[1].get_legend().remove()
axes[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Look at disentanglement scores
# %% pycharm={"name": "#%%\n"}
# Perturbation and covariate disentanglement per embedding model; the dotted
# horizontal line marks the acceptance threshold for each score.
rows = 2
cols = 1
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 7 * rows), sharex=True)
# Thresholds: 0.1 for perturbation, 0.8 for covariate disentanglement.
max_entangle = [0.1, 0.8]
for i, y in enumerate(
    ["result.perturbation disentanglement", "result.covariate disentanglement"]
):
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=y,
        inner="point",
        ax=ax[i],
        hue="config.model.load_pretrained",
    )
    # ax[i].set_ylim([0,1])
    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].axhline(max_entangle[i], ls=":", color="black")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])
# Single shared legend next to the bottom panel.
ax[0].get_legend().remove()
ax[1].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Subselect to disentangled models
# %% pycharm={"name": "#%%\n"}
# Number of best runs to keep per (embedding, pretrained) combination.
n_top = 2


def performance_condition(emb, pretrained, max_entangle, max_entangle_cov, df=None):
    """Return a boolean mask selecting sufficiently disentangled runs.

    Selects rows of `df` whose embedding model equals `emb`, whose
    `load_pretrained` flag equals `pretrained`, and whose perturbation /
    covariate disentanglement scores are strictly below `max_entangle` /
    `max_entangle_cov`.

    `df` defaults to the notebook-global `results_clean` (backward compatible
    with the original signature); pass a frame explicitly for reuse/testing.
    """
    df = results_clean if df is None else df
    cond = df["config.model.embedding.model"] == emb
    cond = cond & (df["result.perturbation disentanglement"] < max_entangle)
    cond = cond & (df["result.covariate disentanglement"] < max_entangle_cov)
    cond = cond & (df["config.model.load_pretrained"] == pretrained)
    return cond
best = []
# For every (embedding, pretrained) combination keep the `n_top` runs with the
# highest validation DE-gene R^2. Thresholds of 1 effectively disable the
# disentanglement filter here.
for embedding in list(results_clean["config.model.embedding.model"].unique()):
    for pretrained in [True, False]:
        df = results_clean[performance_condition(embedding, pretrained, 1, 1)]
        # log how many candidate runs each combination has
        print(embedding, pretrained, len(df))
        best.append(
            df.sort_values(by="result.val_mean_de", ascending=False).head(n_top)
        )
best = pd.concat(best)
# %% pycharm={"name": "#%%\n"}
# Test-set R^2 (all genes / DE genes) and both disentanglement scores for the
# selected best runs.
plot_cols = [
    "result.test_mean",
    "result.test_mean_de",
    "result.perturbation disentanglement",
    "result.covariate disentanglement",
]
rows, cols = 1, 4
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))
for i, col in enumerate(plot_cols):
    sns.violinplot(
        data=best,
        x="config.model.embedding.model",
        y=col,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",
    )
    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(col.split(".")[-1])
    ax[i].legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)
# Keep a single legend, placed outside the rightmost panel.
for panel in ax[:3]:
    panel.get_legend().remove()
# ax[0].set_ylim([0.4, 1.01])
# ax[1].set_ylim([0.4, 1.01])
ax[3].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# %% pycharm={"name": "#%%\n"}
# Training-set R^2 (all genes / DE genes) and perturbation disentanglement for
# the selected best runs.
rows, cols = 1, 3
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))
for i, y in enumerate(
    [
        "result.training_mean",
        "result.training_mean_de",
        "result.perturbation disentanglement",
    ]
):
    sns.violinplot(
        data=best,
        x="config.model.embedding.model",
        y=y,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",
    )
    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])
# Drop the auto-created per-panel legends and keep one next to the last panel.
ax[0].get_legend().remove()
ax[0].set_ylim([0.4, 1.01])
ax[1].get_legend().remove()
ax[1].set_ylim([0.4, 1.01])
ax[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# %% [markdown] pycharm={"name": "#%% md\n"}
# ## Take a deeper look in the `.config` of the best performing models
# %% pycharm={"name": "#%%\n"}
# Which columns identify a run's configuration?
[c for c in best.columns if "hash" in c]
# %% pycharm={"name": "#%%\n"}
# Config hashes of the best runs alongside their test DE-gene R^2 and covariate
# disentanglement, for looking them up in the experiment database.
best[
    [
        "config.model.load_pretrained",
        "config_hash",
        "result.test_mean_de",
        "result.covariate disentanglement",
    ]
]
# %% pycharm={"name": "#%%\n"}