chemCPA / notebooks /chemCPA_Table_4.py
github-actions[bot]
HF snapshot
a48f0ae
# ---
# jupyter:
# jupytext:
# notebook_metadata_filter: -kernelspec
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.1
# ---
# %% [markdown]
# **Requirements:**
# * Trained models
#
# **Outputs:**
# * none
# ___
# # Imports
# %%
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import umap.plot
from anndata import AnnData
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from utils import (
compute_drug_embeddings,
compute_pred,
compute_pred_ctrl,
load_config,
load_dataset,
load_model,
load_smiles,
)
from chemCPA.data import load_dataset_splits
from chemCPA.paths import FIGURE_DIR, ROOT
# %%
BLACK = False
SAVEFIG = False
# %%
if BLACK:
plt.style.use("dark_background")
else:
matplotlib.style.use("fivethirtyeight")
matplotlib.style.use("seaborn-talk")
matplotlib.pyplot.rcParams["savefig.facecolor"] = "white"
sns.set_style("whitegrid")
matplotlib.rcParams["font.family"] = "monospace"
matplotlib.rcParams["figure.dpi"] = 300
sns.set_context("poster")
# %%
# %load_ext autoreload
# %autoreload 2
# %% [markdown]
# ## Plotting function
# %%
def plot_umap(
mapper,
canon_smiles_unique_sorted,
smiles_to_pathway_map,
smiles_to_drug_map,
groups=[
"Epigenetic regulation",
"Tyrosine kinase signaling",
"Cell cycle regulation",
],
ood_drugs=[],
hue=None,
ax=None,
):
# important to use the same ordering of SMILES as was used for getting the embedding!
if ax == None:
fig, ax = plt.subplots(figsize=(12, 8))
# groups=["Tyrosine kinase signaling"]
pathway = []
drugs = []
for s in canon_smiles_unique_sorted:
if s in smiles_to_pathway_map:
pathway.append(smiles_to_pathway_map[s])
drugs.append(smiles_to_drug_map[s])
else:
pathway.append("other")
drugs.append("unknown")
pathway = np.where(pd.Series(pathway).isin(groups), pathway, "other")
# cmap = [(0.7,0.7,0.7)]
# # cmap.extend(list(plt.get_cmap('tab20').colors))
# # cmap = tuple(cmap)
# sns.scatterplot(x=mapper[:,0], y=mapper[:,1], hue=pathway, palette=cmap, ax=ax)
if hue is None:
cond = pathway != "other"
sns.scatterplot(
x=mapper[cond, 0],
y=mapper[cond, 1],
hue=pathway[cond],
ax=ax,
palette="tab10" if BLACK else None,
)
sns.scatterplot(
x=mapper[~cond, 0], y=mapper[~cond, 1], ax=ax, color="grey", alpha=0.3
)
else:
sns.scatterplot(x=mapper[:, 0], y=mapper[:, 1], hue=hue, ax=ax)
shift = 0.75
for i, label in enumerate(drugs):
if drugs[i] in ood_drugs:
ax.text(
x=mapper[i, 0] + shift,
y=mapper[i, 1] + shift,
s=label,
fontdict=dict(
color="white" if BLACK else "black", alpha=1, size=16, weight=600
),
bbox=dict(facecolor="black" if BLACK else "lightgrey", alpha=0.3),
)
x = mapper[i, 0]
y = mapper[i, 1]
ax.annotate(
"",
xy=(x, y),
xytext=(x + shift, y + shift),
arrowprops=dict(
arrowstyle="->",
# facecolor='black',
edgecolor="black",
),
# fontdict=dict(color='white' if BLACK else 'black', alpha=1, size=12, weight=600),
# bbox=dict(facecolor='black' if BLACK else 'lightgrey', alpha=0.3)
)
# %% [markdown]
# # Load and analyse model
# * Define `seml_collection` and `model_hash` to load data and model
# %%
# Extended
seml_collection = "multi_task"
model_hash_pretrained_rdkit = "dde01c1c58f398d524453c4b564a440f" # Fine-tuned
model_hash_scratch_rdkit = "475e26950b2c531bea88931a4b2c27b7" # Non-pretrained
model_hash_pretrained_grover = "0f4a3b11e1fbe3da58125f39ff6fb035" # Fine-tuned
model_hash_scratch_grover = "b372147c80cf9ad4bd10d16bc56b7534" # Non-pretrained
model_hash_pretrained_jtvae = "e4eac660c5830245f681ec3cc5099f21" # Fine-tuned
model_hash_scratch_jtvae = "6b465400467f69da861e3ef0b4709e19" # Non-pretrained
# SHARED
seml_collection = "multi_task"
model_hash_pretrained_rdkit = "c824e42f7ce751cf9a8ed26f0d9e0af7" # Fine-tuned
model_hash_scratch_rdkit = "59bdaefb1c1adfaf2976e3fdf62afa21" # Non-pretrained
model_hash_pretrained_grover = "c30016a7469feb78a8ee9ebb18ed9b1f" # Fine-tuned
model_hash_scratch_grover = "60e4b40e8d67bff2d5efc5e22e265820" # Non-pretrained
model_hash_pretrained_jtvae = "915345a522c29fa709b995d6149083b9" # Fine-tuned
model_hash_scratch_jtvae = "934c89b742a6309ad6bb2e1cf90c5e50" # Non-pretrained
# %%
model_hash_pretrained = model_hash_pretrained_rdkit
# %% [markdown]
# ___
# ## Pretrained model
# %% [markdown]
# ### Load model
# %%
config = load_config(seml_collection, model_hash_pretrained)
config["dataset"]["data_params"]["dataset_path"] = (
ROOT / config["dataset"]["data_params"]["dataset_path"]
)
config["model"]["embedding"]["directory"] = (
ROOT / config["model"]["embedding"]["directory"]
)
dataset, key_dict = load_dataset(config)
config["dataset"]["n_vars"] = dataset.n_vars
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles(
config, dataset, key_dict, True
)
model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted)
# %% [markdown]
# #### Define which drugs should be annotaded with list `ood_drugs`
# %%
ood_drugs = (
dataset.obs.condition[dataset.obs.split_ood_multi_task.isin(["ood"])]
.unique()
.to_list()
)
# %%
ood_drugs
# %% [markdown]
# #### Get pathway level 2 annotation for clustering of drug embeddings
# %%
smiles_to_pw_level2_map = {}
pw1_to_pw2 = {}
for (drug, pw1, pw2), df in dataset.obs.groupby(
["SMILES", "pathway_level_1", "pathway_level_2"]
):
smiles_to_pw_level2_map[drug] = pw2
if pw1 in pw1_to_pw2:
pw1_to_pw2[pw1].add(pw2)
else:
pw1_to_pw2[pw1] = {pw2}
# %% [markdown]
# # Compute drug perturbation (RDKit)
# %%
transf_embeddings_pretrained_high = compute_drug_embeddings(
model_pretrained, embedding_pretrained, dosage=1e4
)
mapper_pretrained_high = umap.UMAP(
n_neighbors=25, min_dist=1, spread=2, metric="euclidean"
).fit_transform(transf_embeddings_pretrained_high)
transf_embeddings_pretrained_low = compute_drug_embeddings(
model_pretrained, embedding_pretrained, dosage=10
)
mapper_pretrained_low = umap.UMAP(n_neighbors=25, min_dist=1, spread=2).fit_transform(
transf_embeddings_pretrained_low
)
# %% [markdown]
# # Plot pretrained RDKit fingerprints
# %%
drug_emb_pretrained = np.array(embedding_pretrained.weight.data)
# %%
for s, pw in smiles_to_pw_level2_map.items():
if pw == "Histone deacetylation":
smiles_to_pathway_map[s] = pw
# %%
SAVEFIG = True
drugs_label = [
"Quisinostat",
# 'Hesperadin',
# 'Flavopiridol',
"Belinostat",
# 'Alvespimycin',
# 'TAK-901',
"Dacinostat",
# 'Tanespimycin',
"Givinostat",
]
_groups = [
"Histone deacetylation",
"Cell cycle regulation",
"Tyrosine kinase signaling",
]
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
plot_umap(
TSNE().fit_transform(drug_emb_pretrained),
canon_smiles_unique_sorted,
smiles_to_pathway_map,
smiles_to_drug_map,
groups=_groups,
ood_drugs=drugs_label,
ax=ax,
)
ax.legend(
title="Pathway",
fontsize=18,
title_fontsize=22,
loc="upper left",
bbox_to_anchor=(1, 1),
)
plt.tight_layout()
if SAVEFIG:
if BLACK:
plt.savefig(
FIGURE_DIR / "rdkit_pretrained_drug_embedding_balck.png", format="png"
)
else:
plt.savefig(FIGURE_DIR / "rdkit_pretrained_drug_embedding.png", format="png")
SAVEFIG = False
# %% [markdown]
# _____
# %% [markdown]
# # Uncertainty
# %%
drug_emb = transf_embeddings_pretrained_high
adata = AnnData(drug_emb.numpy())
adata.X = StandardScaler().fit_transform(drug_emb)
drugs = [smiles_to_drug_map[s] for s in canon_smiles_unique_sorted]
adata.obs["drug"] = drugs
adata.obs["pathway"] = [smiles_to_pathway_map[s] for s in canon_smiles_unique_sorted]
adata.obs["split"] = "train"
adata.obs.loc[adata.obs.drug.isin(ood_drugs), "split"] = "ood"
# %% [markdown]
# Preprocess adata object
# %%
sc.pp.pca(adata, n_comps=15)
# %% [markdown]
# Baseline score
# %%
data_params = config["dataset"]["data_params"]
datasets, dataset_all = load_dataset_splits(**data_params, return_dataset=True)
# %%
dosages = [1e1, 1e2, 1e3, 1e4]
dosages = [1e4]
cell_lines = ["A549", "K562", "MCF7"]
use_DEGs = True
# %%
drug_r2_baseline_degs, _ = compute_pred_ctrl(
dataset=dataset_all, # dataset_test_ood,
dataset_ctrl=datasets["test_control"],
dosages=dosages,
cell_lines=cell_lines,
use_DEGs=True,
verbose=False,
)
# %%
baseline_key = pd.Series(drug_r2_baseline_degs.keys())
baseline_val = pd.Series(drug_r2_baseline_degs.values())
adata.obs["baseline_perf"] = ""
for d in adata.obs.drug:
cond = baseline_key.apply(lambda s: d in s)
adata.obs.loc[adata.obs.drug == d, "baseline_perf"] = baseline_val[cond].mean()
# baseline_val
# %% [markdown]
# Pretrained model score
# %%
drug_r2_pretrained_degs, _ = compute_pred(
model_pretrained,
dataset_all,
genes_control=datasets["test_control"].genes,
dosages=dosages,
cell_lines=cell_lines,
use_DEGs=True,
verbose=False,
)
# %%
pretrained_key = pd.Series(drug_r2_pretrained_degs.keys())
pretrained_val = pd.Series(drug_r2_pretrained_degs.values())
adata.obs["pretrained_perf"] = ""
for d in adata.obs.drug:
cond = pretrained_key.apply(lambda s: d in s)
adata.obs.loc[adata.obs.drug == d, "pretrained_perf"] = pretrained_val[cond].mean()
# pretrained_val
# %%
adata.obs["difference"] = adata.obs.pretrained_perf - adata.obs.baseline_perf
# %% [markdown]
# Compute knn-graph, neighbours, and distances
# %%
sc.pp.neighbors(adata, n_neighbors=10)
# %%
neighbours = np.where(adata.uns["neighbors"]["distances"].A != 0)
distances = adata.obsp["distances"].A
# %%
def entropy(column, base=None):
vc = pd.Series(column).value_counts(normalize=True, sort=False)
base = np.exp if base is None else base
return -(vc * np.log(vc) / np.log(base)).sum()
# %%
adata.obs["n_neighbours"] = ""
adata.obs["pathway_acc"] = ""
adata.obs["neighbor_pred"] = ""
adata.obs["correct"] = ""
for i in range(188):
cond = neighbours[0] == i
pathway = adata.obs.pathway.iloc[i]
pathways = adata.obs.pathway.iloc[neighbours[1][cond]]
baselines = adata.obs.baseline_perf.iloc[neighbours[1][cond]]
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "n_neighbours"] = len(
pathways
)
# adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i],'pathway_acc'] = (pathway==pathways).sum()/len(pathways)
pw, count = np.unique(pathways, return_counts=True)
# adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], 'neighbor_pred'] = pw[count.argmax()]
# adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], 'pw_pred_count'] = count.max()/len(pathways)
# adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], 'correct'] = 1 if pw[count.argmax()]==pathway else 0
# adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], 'baseline_neighbors'] = baselines.mean()
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "heterogeneity"] = len(pw)
# if adata.obs.drug.iloc[i] in ood_drugs:
# print(distances[i].sum())
# print(entropy(pathways, base=2))
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "uncertainty"] = (
1 / np.log(distances[i].sum()) * entropy(pathways, base=2)
)
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "distance"] = distances[
i
].sum()
# %% [markdown]
# ## Table 4
# %%
cols = ["drug", "distance", "difference", "uncertainty"]
print(
adata.obs.loc[adata.obs.drug.isin(ood_drugs), cols]
.sort_values("uncertainty")
.to_latex()
)
# %% [markdown]
# ## Supplement Figure 14
# %%
SAVEFIG = False
fig, ax = plt.subplots(figsize=(12, 5))
sns.scatterplot(
data=adata.obs,
x="difference",
y="uncertainty",
ax=ax,
hue="baseline_perf",
alpha=0.9,
)
ax.set_xlabel(r"Improvement of $r^2$ over baseline at 10$\mu$M")
ax.set_ylabel("Uncertainty")
ax.legend(title=r"Baseline $r^2$", fontsize=18, title_fontsize=22, loc="best")
plt.tight_layout()
if SAVEFIG:
if BLACK:
plt.savefig(FIGURE_DIR / "uncertainty_drug_embedding_balck.png", format="png")
else:
plt.savefig(FIGURE_DIR / "uncertainty_drug_embedding.png", format="png")
SAVEFIG = False
# %%
adata.obsm["umap"] = mapper_pretrained_high
# %%
SAVEFIG = True
groups = [
"Histone deacetylation",
"Tyrosine kinase signaling",
"Cell cycle regulation",
"DNA damage & DNA repair",
]
fig, ax = plt.subplots(1, 2, figsize=(19, 5), sharex=True)
plot_umap(
mapper_pretrained_high,
canon_smiles_unique_sorted,
smiles_to_pathway_map,
smiles_to_drug_map,
groups=groups,
ood_drugs=ood_drugs,
ax=ax[0],
)
uncertainty = adata.obs["uncertainty"].values
plot_umap(
mapper_pretrained_high,
canon_smiles_unique_sorted,
smiles_to_pathway_map,
smiles_to_drug_map,
groups=[""],
hue=uncertainty,
ood_drugs=ood_drugs,
ax=ax[1],
)
ax[1].set(xticklabels=[], yticklabels=[])
# ax[1].set_xlabel(f"UMAP of $z_d$ for a dosage of $10\,\mu$M")
ax[1].legend(
title="Uncertainty",
labels=["high", "low"],
loc="lower left",
bbox_to_anchor=(0.95, 0.1),
bbox_transform=fig.transFigure,
)
ax[0].set(xticklabels=[], yticklabels=[])
ax[0].grid(False)
ax[1].grid(False)
# ax[0].get_legend().remove()
ax[0].legend(
title="Pathway",
fontsize=18,
title_fontsize=22,
loc="upper left",
bbox_to_anchor=(0.95, 0.9),
bbox_transform=fig.transFigure,
)
# plt.tight_layout()
plt.subplots_adjust(wspace=0.05, hspace=0.05)
# fig.supxlabel(f"UMAP of $z_d$ for a dosage of $10\,\mu$M")
fig.text(
0.5, 0.0, f"UMAP of $z_d$ for a dosage of $10\,\mu$M", ha="center", va="center"
)
if SAVEFIG:
if BLACK:
plt.savefig(
FIGURE_DIR / "umap_drug_embedding_uncertainty_black.png", format="png"
)
else:
plt.savefig(
FIGURE_DIR / "umap_drug_embedding_uncertainty.png",
format="png",
bbox_inches="tight",
)
# %%