from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
import seml
from matplotlib import pyplot as plt

matplotlib.style.use("fivethirtyeight")
# Note: on matplotlib >= 3.6 the seaborn styles are prefixed, i.e. "seaborn-v0_8-talk".
matplotlib.style.use("seaborn-talk")
matplotlib.rcParams["font.family"] = "monospace"
plt.rcParams["savefig.facecolor"] = "white"
sns.set_context("poster")
pd.set_option("display.max_columns", 100)
# Load all completed runs of the `check_seml` experiment from the seml database.
results = seml.get_results(
    "check_seml",
    to_data_frame=True,
    fields=["config", "result", "seml", "config_hash"],
    states=["COMPLETED"],
    filter_dict={
        "batch_id": 1,
        "config.dataset.data_params.split_key": "split_ood_multi_task",
        "config.model.append_ae_layer": True,
    },
)

# Number of runs per embedding model.
results["config.model.embedding.model"].value_counts()

# Which embedding models have runs with a missing perturbation disentanglement score?
pd.crosstab(
    results["config.model.embedding.model"],
    results["result.perturbation disentanglement"].isnull(),
)

[c for c in results.columns if "ae" in c]

pd.crosstab(
    results["config.dataset.data_params.split_key"],
    results["config.model.load_pretrained"],
)

pd.crosstab(
    results["config.dataset.data_params.split_key"],
    results["result.loss_reconstruction"].isnull(),
)

# Columns that contain at least one NaN.
results.isnull().any()[results.isnull().any()]
# Keep only runs that recorded a training history; everything else is treated as invalid.
clean_id = results.loc[~results["result.training"].isnull(), "_id"]

results_clean = results[results._id.isin(clean_id)].copy()
print(f"Percentage of invalid (nan) runs: {1 - len(clean_id) / len(results):.1%}")
results_clean["config.model.embedding.model"].value_counts()

results_clean[
    ["config.model.load_pretrained", "config.model.append_ae_layer", "result.ood"]
]
# The per-epoch metric lists store pairs [mean score, mean score on DE genes];
# keep the values from the last recorded epoch.
get_mean = lambda x: np.array(x)[-1, 0]
get_mean_de = lambda x: np.array(x)[-1, 1]

results_clean["result.training_mean"] = results_clean["result.training"].apply(get_mean)
results_clean["result.training_mean_de"] = results_clean["result.training"].apply(
    get_mean_de
)
results_clean["result.val_mean"] = results_clean["result.test"].apply(get_mean)
results_clean["result.val_mean_de"] = results_clean["result.test"].apply(get_mean_de)
results_clean["result.test_mean"] = results_clean["result.ood"].apply(get_mean)
results_clean["result.test_mean_de"] = results_clean["result.ood"].apply(get_mean_de)

# The disentanglement scores and the reconstruction loss are stored as (nested) lists;
# unwrap them to scalars.
results_clean["result.perturbation disentanglement"] = results_clean[
    "result.perturbation disentanglement"
].apply(lambda x: x[0])
results_clean["result.covariate disentanglement"] = results_clean[
    "result.covariate disentanglement"
].apply(lambda x: x[0][0])
results_clean["result.final_reconstruction"] = results_clean[
    "result.loss_reconstruction"
].apply(lambda x: x[-1])

results_clean.head(3)
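# A quick overview of the derived metric columns (not in the original notebook,
# assuming the columns created above):
results_clean[
    [
        "result.training_mean",
        "result.val_mean",
        "result.test_mean",
        "result.test_mean_de",
        "result.perturbation disentanglement",
        "result.covariate disentanglement",
        "result.final_reconstruction",
    ]
].describe()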
# Distribution of the total number of training epochs, split by whether the model
# was pretrained (config.model.load_pretrained).
fig, ax = plt.subplots(2, 1)
sns.histplot(
    data=results_clean[results_clean["config.model.load_pretrained"] == True][
        "result.epoch"
    ].apply(max),
    ax=ax[0],
)
ax[0].set_title("Total epochs before final stopping (min 125), pretrained")

sns.histplot(
    data=results_clean[results_clean["config.model.load_pretrained"] == False][
        "result.epoch"
    ].apply(max),
    ax=ax[1],
)
ax[1].set_title("Total epochs before final stopping (min 125), non pretrained")

plt.tight_layout()
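# Optionally persist the figure (not part of the original notebook; the filename is
# only an example, savefig.facecolor is already set to white above):
fig.savefig("epochs_before_stopping.png", dpi=200, bbox_inches="tight")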
[c for c in results_clean.columns if "pretrain" in c]

results_clean[
    [
        "config.model.embedding.model",
        "config.model.load_pretrained",
        "config.model.append_ae_layer",
    ]
]
# Mean metric on DE genes for train/val/test (ood), per embedding model,
# split by pretraining.
rows, cols = 1, 3
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))

for i, y in enumerate(
    ("result.training_mean_de", "result.val_mean_de", "result.test_mean_de")
):
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=y,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",  # renamed to `density_norm` in seaborn >= 0.13
    )

    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])
    ax[i].legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)

ax[0].get_legend().remove()
ax[1].get_legend().remove()
ax[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
rows, cols = 1, 3
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))

for i, y in enumerate(("result.training_mean", "result.val_mean", "result.test_mean")):
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=y,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",
    )

    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])
    ax[i].legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)

ax[0].get_legend().remove()
ax[1].get_legend().remove()
ax[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# Perturbation and covariate disentanglement per embedding model; the dotted lines
# mark reference thresholds (max_entangle) for acceptable disentanglement.
rows = 2
cols = 1
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 7 * rows), sharex=True)

max_entangle = [0.1, 0.8]
for i, y in enumerate(
    ["result.perturbation disentanglement", "result.covariate disentanglement"]
):
    sns.violinplot(
        data=results_clean,
        x="config.model.embedding.model",
        y=y,
        inner="point",
        ax=ax[i],
        hue="config.model.load_pretrained",
    )

    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].axhline(max_entangle[i], ls=":", color="black")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])

ax[0].get_legend().remove()
ax[1].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
# Select the n_top best runs per (embedding model, pretrained) combination,
# ranked by the mean DE metric on the validation split.
n_top = 2


def performance_condition(emb, pretrained, max_entangle, max_entangle_cov):
    cond = results_clean["config.model.embedding.model"] == emb
    cond = cond & (results_clean["result.perturbation disentanglement"] < max_entangle)
    cond = cond & (results_clean["result.covariate disentanglement"] < max_entangle_cov)
    cond = cond & (results_clean["config.model.load_pretrained"] == pretrained)
    return cond


best = []
for embedding in list(results_clean["config.model.embedding.model"].unique()):
    for pretrained in [True, False]:
        df = results_clean[performance_condition(embedding, pretrained, 1, 1)]
        print(embedding, pretrained, len(df))
        best.append(
            df.sort_values(by="result.val_mean_de", ascending=False).head(n_top)
        )

best = pd.concat(best)
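# Quick check of how many selected runs remain per embedding model
# (not in the original notebook):
best["config.model.embedding.model"].value_counts()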
rows, cols = 1, 4
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))

for i, y in enumerate(
    [
        "result.test_mean",
        "result.test_mean_de",
        "result.perturbation disentanglement",
        "result.covariate disentanglement",
    ]
):
    sns.violinplot(
        data=best,
        x="config.model.embedding.model",
        y=y,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",
    )
    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])
    ax[i].legend(title="Pretrained", loc="lower right", fontsize=18, title_fontsize=24)

ax[0].get_legend().remove()
ax[1].get_legend().remove()
ax[2].get_legend().remove()
ax[3].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
rows, cols = 1, 3
fig, ax = plt.subplots(rows, cols, figsize=(10 * cols, 6 * rows))

for i, y in enumerate(
    [
        "result.training_mean",
        "result.training_mean_de",
        "result.perturbation disentanglement",
    ]
):
    sns.violinplot(
        data=best,
        x="config.model.embedding.model",
        y=y,
        hue="config.model.load_pretrained",
        inner="points",
        ax=ax[i],
        scale="width",
    )
    ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=75, ha="right")
    ax[i].set_xlabel("")
    ax[i].set_ylabel(y.split(".")[-1])

ax[0].get_legend().remove()
ax[0].set_ylim([0.4, 1.01])
ax[1].get_legend().remove()
ax[1].set_ylim([0.4, 1.01])
ax[2].legend(
    title="Pretrained",
    fontsize=18,
    title_fontsize=24,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()
[c for c in best.columns if "hash" in c]

best[
    [
        "config.model.load_pretrained",
        "config_hash",
        "result.test_mean_de",
        "result.covariate disentanglement",
    ]
]
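# The config hashes of the selected runs could be exported for downstream use,
# e.g. to look up the corresponding checkpoints; the filename is only an example
# (not in the original notebook):
best["config_hash"].to_csv("best_config_hashes.csv", index=False)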