|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import scanpy as sc |
|
import seaborn as sns |
|
import umap.plot |
|
from anndata import AnnData |
|
from sklearn.manifold import TSNE |
|
from sklearn.preprocessing import StandardScaler |
|
from utils import ( |
|
compute_drug_embeddings, |
|
compute_pred, |
|
compute_pred_ctrl, |
|
load_config, |
|
load_dataset, |
|
load_model, |
|
load_smiles, |
|
) |
|
|
|
from chemCPA.data import load_dataset_splits |
|
from chemCPA.paths import FIGURE_DIR, ROOT |
|
|
|
|
|
BLACK = False |
|
SAVEFIG = False |
|
|
|
|
|
if BLACK: |
|
plt.style.use("dark_background") |
|
else: |
|
matplotlib.style.use("fivethirtyeight") |
|
matplotlib.style.use("seaborn-talk") |
|
matplotlib.pyplot.rcParams["savefig.facecolor"] = "white" |
|
sns.set_style("whitegrid") |
|
|
|
matplotlib.rcParams["font.family"] = "monospace" |
|
matplotlib.rcParams["figure.dpi"] = 300 |
|
sns.set_context("poster") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_umap( |
|
mapper, |
|
canon_smiles_unique_sorted, |
|
smiles_to_pathway_map, |
|
smiles_to_drug_map, |
|
groups=[ |
|
"Epigenetic regulation", |
|
"Tyrosine kinase signaling", |
|
"Cell cycle regulation", |
|
], |
|
ood_drugs=[], |
|
hue=None, |
|
ax=None, |
|
): |
|
|
|
if ax == None: |
|
fig, ax = plt.subplots(figsize=(12, 8)) |
|
|
|
|
|
pathway = [] |
|
drugs = [] |
|
|
|
for s in canon_smiles_unique_sorted: |
|
if s in smiles_to_pathway_map: |
|
pathway.append(smiles_to_pathway_map[s]) |
|
drugs.append(smiles_to_drug_map[s]) |
|
else: |
|
pathway.append("other") |
|
drugs.append("unknown") |
|
|
|
pathway = np.where(pd.Series(pathway).isin(groups), pathway, "other") |
|
|
|
|
|
|
|
|
|
|
|
|
|
if hue is None: |
|
cond = pathway != "other" |
|
sns.scatterplot( |
|
x=mapper[cond, 0], |
|
y=mapper[cond, 1], |
|
hue=pathway[cond], |
|
ax=ax, |
|
palette="tab10" if BLACK else None, |
|
) |
|
sns.scatterplot( |
|
x=mapper[~cond, 0], y=mapper[~cond, 1], ax=ax, color="grey", alpha=0.3 |
|
) |
|
else: |
|
sns.scatterplot(x=mapper[:, 0], y=mapper[:, 1], hue=hue, ax=ax) |
|
|
|
shift = 0.75 |
|
for i, label in enumerate(drugs): |
|
if drugs[i] in ood_drugs: |
|
ax.text( |
|
x=mapper[i, 0] + shift, |
|
y=mapper[i, 1] + shift, |
|
s=label, |
|
fontdict=dict( |
|
color="white" if BLACK else "black", alpha=1, size=16, weight=600 |
|
), |
|
bbox=dict(facecolor="black" if BLACK else "lightgrey", alpha=0.3), |
|
) |
|
x = mapper[i, 0] |
|
y = mapper[i, 1] |
|
ax.annotate( |
|
"", |
|
xy=(x, y), |
|
xytext=(x + shift, y + shift), |
|
arrowprops=dict( |
|
arrowstyle="->", |
|
|
|
edgecolor="black", |
|
), |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seml_collection = "multi_task" |
|
|
|
model_hash_pretrained_rdkit = "dde01c1c58f398d524453c4b564a440f" |
|
model_hash_scratch_rdkit = "475e26950b2c531bea88931a4b2c27b7" |
|
|
|
model_hash_pretrained_grover = "0f4a3b11e1fbe3da58125f39ff6fb035" |
|
model_hash_scratch_grover = "b372147c80cf9ad4bd10d16bc56b7534" |
|
|
|
model_hash_pretrained_jtvae = "e4eac660c5830245f681ec3cc5099f21" |
|
model_hash_scratch_jtvae = "6b465400467f69da861e3ef0b4709e19" |
|
|
|
|
|
|
|
seml_collection = "multi_task" |
|
|
|
model_hash_pretrained_rdkit = "c824e42f7ce751cf9a8ed26f0d9e0af7" |
|
model_hash_scratch_rdkit = "59bdaefb1c1adfaf2976e3fdf62afa21" |
|
|
|
model_hash_pretrained_grover = "c30016a7469feb78a8ee9ebb18ed9b1f" |
|
model_hash_scratch_grover = "60e4b40e8d67bff2d5efc5e22e265820" |
|
|
|
model_hash_pretrained_jtvae = "915345a522c29fa709b995d6149083b9" |
|
model_hash_scratch_jtvae = "934c89b742a6309ad6bb2e1cf90c5e50" |
|
|
|
|
|
model_hash_pretrained = model_hash_pretrained_rdkit |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config = load_config(seml_collection, model_hash_pretrained) |
|
|
|
config["dataset"]["data_params"]["dataset_path"] = ( |
|
ROOT / config["dataset"]["data_params"]["dataset_path"] |
|
) |
|
config["model"]["embedding"]["directory"] = ( |
|
ROOT / config["model"]["embedding"]["directory"] |
|
) |
|
|
|
dataset, key_dict = load_dataset(config) |
|
|
|
config["dataset"]["n_vars"] = dataset.n_vars |
|
|
|
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles( |
|
config, dataset, key_dict, True |
|
) |
|
model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted) |
|
|
|
|
|
|
|
|
|
|
|
ood_drugs = ( |
|
dataset.obs.condition[dataset.obs.split_ood_multi_task.isin(["ood"])] |
|
.unique() |
|
.to_list() |
|
) |
|
|
|
|
|
ood_drugs |
|
|
|
|
|
|
|
|
|
|
|
smiles_to_pw_level2_map = {} |
|
pw1_to_pw2 = {} |
|
|
|
for (drug, pw1, pw2), df in dataset.obs.groupby( |
|
["SMILES", "pathway_level_1", "pathway_level_2"] |
|
): |
|
smiles_to_pw_level2_map[drug] = pw2 |
|
if pw1 in pw1_to_pw2: |
|
pw1_to_pw2[pw1].add(pw2) |
|
else: |
|
pw1_to_pw2[pw1] = {pw2} |
|
|
|
|
|
|
|
|
|
|
|
transf_embeddings_pretrained_high = compute_drug_embeddings( |
|
model_pretrained, embedding_pretrained, dosage=1e4 |
|
) |
|
mapper_pretrained_high = umap.UMAP( |
|
n_neighbors=25, min_dist=1, spread=2, metric="euclidean" |
|
).fit_transform(transf_embeddings_pretrained_high) |
|
|
|
transf_embeddings_pretrained_low = compute_drug_embeddings( |
|
model_pretrained, embedding_pretrained, dosage=10 |
|
) |
|
mapper_pretrained_low = umap.UMAP(n_neighbors=25, min_dist=1, spread=2).fit_transform( |
|
transf_embeddings_pretrained_low |
|
) |
|
|
|
|
|
|
|
|
|
|
|
drug_emb_pretrained = np.array(embedding_pretrained.weight.data) |
|
|
|
|
|
for s, pw in smiles_to_pw_level2_map.items(): |
|
if pw == "Histone deacetylation": |
|
smiles_to_pathway_map[s] = pw |
|
|
|
|
|
SAVEFIG = True |
|
|
|
drugs_label = [ |
|
"Quisinostat", |
|
|
|
|
|
"Belinostat", |
|
|
|
|
|
"Dacinostat", |
|
|
|
"Givinostat", |
|
] |
|
|
|
_groups = [ |
|
"Histone deacetylation", |
|
"Cell cycle regulation", |
|
"Tyrosine kinase signaling", |
|
] |
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(15, 5)) |
|
|
|
plot_umap( |
|
TSNE().fit_transform(drug_emb_pretrained), |
|
canon_smiles_unique_sorted, |
|
smiles_to_pathway_map, |
|
smiles_to_drug_map, |
|
groups=_groups, |
|
ood_drugs=drugs_label, |
|
ax=ax, |
|
) |
|
ax.legend( |
|
title="Pathway", |
|
fontsize=18, |
|
title_fontsize=22, |
|
loc="upper left", |
|
bbox_to_anchor=(1, 1), |
|
) |
|
plt.tight_layout() |
|
|
|
if SAVEFIG: |
|
if BLACK: |
|
plt.savefig( |
|
FIGURE_DIR / "rdkit_pretrained_drug_embedding_balck.png", format="png" |
|
) |
|
else: |
|
plt.savefig(FIGURE_DIR / "rdkit_pretrained_drug_embedding.png", format="png") |
|
|
|
SAVEFIG = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
drug_emb = transf_embeddings_pretrained_high |
|
|
|
adata = AnnData(drug_emb.numpy()) |
|
adata.X = StandardScaler().fit_transform(drug_emb) |
|
|
|
|
|
drugs = [smiles_to_drug_map[s] for s in canon_smiles_unique_sorted] |
|
adata.obs["drug"] = drugs |
|
adata.obs["pathway"] = [smiles_to_pathway_map[s] for s in canon_smiles_unique_sorted] |
|
adata.obs["split"] = "train" |
|
adata.obs.loc[adata.obs.drug.isin(ood_drugs), "split"] = "ood" |
|
|
|
|
|
|
|
|
|
|
|
sc.pp.pca(adata, n_comps=15) |
|
|
|
|
|
|
|
|
|
|
|
data_params = config["dataset"]["data_params"] |
|
datasets, dataset_all = load_dataset_splits(**data_params, return_dataset=True) |
|
|
|
|
|
dosages = [1e1, 1e2, 1e3, 1e4] |
|
dosages = [1e4] |
|
cell_lines = ["A549", "K562", "MCF7"] |
|
use_DEGs = True |
|
|
|
|
|
drug_r2_baseline_degs, _ = compute_pred_ctrl( |
|
dataset=dataset_all, |
|
dataset_ctrl=datasets["test_control"], |
|
dosages=dosages, |
|
cell_lines=cell_lines, |
|
use_DEGs=True, |
|
verbose=False, |
|
) |
|
|
|
|
|
|
|
baseline_key = pd.Series(drug_r2_baseline_degs.keys()) |
|
baseline_val = pd.Series(drug_r2_baseline_degs.values()) |
|
|
|
adata.obs["baseline_perf"] = "" |
|
|
|
for d in adata.obs.drug: |
|
cond = baseline_key.apply(lambda s: d in s) |
|
adata.obs.loc[adata.obs.drug == d, "baseline_perf"] = baseline_val[cond].mean() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
drug_r2_pretrained_degs, _ = compute_pred( |
|
model_pretrained, |
|
dataset_all, |
|
genes_control=datasets["test_control"].genes, |
|
dosages=dosages, |
|
cell_lines=cell_lines, |
|
use_DEGs=True, |
|
verbose=False, |
|
) |
|
|
|
|
|
pretrained_key = pd.Series(drug_r2_pretrained_degs.keys()) |
|
pretrained_val = pd.Series(drug_r2_pretrained_degs.values()) |
|
|
|
adata.obs["pretrained_perf"] = "" |
|
|
|
for d in adata.obs.drug: |
|
cond = pretrained_key.apply(lambda s: d in s) |
|
adata.obs.loc[adata.obs.drug == d, "pretrained_perf"] = pretrained_val[cond].mean() |
|
|
|
|
|
|
|
|
|
adata.obs["difference"] = adata.obs.pretrained_perf - adata.obs.baseline_perf |
|
|
|
|
|
|
|
|
|
|
|
sc.pp.neighbors(adata, n_neighbors=10) |
|
|
|
|
|
neighbours = np.where(adata.uns["neighbors"]["distances"].A != 0) |
|
distances = adata.obsp["distances"].A |
|
|
|
|
|
|
|
def entropy(column, base=None): |
|
vc = pd.Series(column).value_counts(normalize=True, sort=False) |
|
base = np.exp if base is None else base |
|
return -(vc * np.log(vc) / np.log(base)).sum() |
|
|
|
|
|
|
|
adata.obs["n_neighbours"] = "" |
|
adata.obs["pathway_acc"] = "" |
|
adata.obs["neighbor_pred"] = "" |
|
adata.obs["correct"] = "" |
|
|
|
for i in range(188): |
|
cond = neighbours[0] == i |
|
pathway = adata.obs.pathway.iloc[i] |
|
pathways = adata.obs.pathway.iloc[neighbours[1][cond]] |
|
baselines = adata.obs.baseline_perf.iloc[neighbours[1][cond]] |
|
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "n_neighbours"] = len( |
|
pathways |
|
) |
|
|
|
pw, count = np.unique(pathways, return_counts=True) |
|
|
|
|
|
|
|
|
|
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "heterogeneity"] = len(pw) |
|
|
|
|
|
|
|
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "uncertainty"] = ( |
|
1 / np.log(distances[i].sum()) * entropy(pathways, base=2) |
|
) |
|
adata.obs.loc[adata.obs.drug == adata.obs.drug.iloc[i], "distance"] = distances[ |
|
i |
|
].sum() |
|
|
|
|
|
|
|
|
|
|
|
cols = ["drug", "distance", "difference", "uncertainty"] |
|
|
|
print( |
|
adata.obs.loc[adata.obs.drug.isin(ood_drugs), cols] |
|
.sort_values("uncertainty") |
|
.to_latex() |
|
) |
|
|
|
|
|
|
|
|
|
|
|
SAVEFIG = False |
|
fig, ax = plt.subplots(figsize=(12, 5)) |
|
sns.scatterplot( |
|
data=adata.obs, |
|
x="difference", |
|
y="uncertainty", |
|
ax=ax, |
|
hue="baseline_perf", |
|
alpha=0.9, |
|
) |
|
ax.set_xlabel(r"Improvement of $r^2$ over baseline at 10$\mu$M") |
|
ax.set_ylabel("Uncertainty") |
|
ax.legend(title=r"Baseline $r^2$", fontsize=18, title_fontsize=22, loc="best") |
|
plt.tight_layout() |
|
|
|
if SAVEFIG: |
|
if BLACK: |
|
plt.savefig(FIGURE_DIR / "uncertainty_drug_embedding_balck.png", format="png") |
|
else: |
|
plt.savefig(FIGURE_DIR / "uncertainty_drug_embedding.png", format="png") |
|
SAVEFIG = False |
|
|
|
|
|
adata.obsm["umap"] = mapper_pretrained_high |
|
|
|
|
|
SAVEFIG = True |
|
|
|
|
|
groups = [ |
|
"Histone deacetylation", |
|
"Tyrosine kinase signaling", |
|
"Cell cycle regulation", |
|
"DNA damage & DNA repair", |
|
] |
|
|
|
|
|
fig, ax = plt.subplots(1, 2, figsize=(19, 5), sharex=True) |
|
|
|
plot_umap( |
|
mapper_pretrained_high, |
|
canon_smiles_unique_sorted, |
|
smiles_to_pathway_map, |
|
smiles_to_drug_map, |
|
groups=groups, |
|
ood_drugs=ood_drugs, |
|
ax=ax[0], |
|
) |
|
|
|
uncertainty = adata.obs["uncertainty"].values |
|
plot_umap( |
|
mapper_pretrained_high, |
|
canon_smiles_unique_sorted, |
|
smiles_to_pathway_map, |
|
smiles_to_drug_map, |
|
groups=[""], |
|
hue=uncertainty, |
|
ood_drugs=ood_drugs, |
|
ax=ax[1], |
|
) |
|
|
|
|
|
ax[1].set(xticklabels=[], yticklabels=[]) |
|
|
|
ax[1].legend( |
|
title="Uncertainty", |
|
labels=["high", "low"], |
|
loc="lower left", |
|
bbox_to_anchor=(0.95, 0.1), |
|
bbox_transform=fig.transFigure, |
|
) |
|
|
|
ax[0].set(xticklabels=[], yticklabels=[]) |
|
|
|
ax[0].grid(False) |
|
ax[1].grid(False) |
|
|
|
ax[0].legend( |
|
title="Pathway", |
|
fontsize=18, |
|
title_fontsize=22, |
|
loc="upper left", |
|
bbox_to_anchor=(0.95, 0.9), |
|
bbox_transform=fig.transFigure, |
|
) |
|
|
|
|
|
plt.subplots_adjust(wspace=0.05, hspace=0.05) |
|
|
|
fig.text( |
|
0.5, 0.0, f"UMAP of $z_d$ for a dosage of $10\,\mu$M", ha="center", va="center" |
|
) |
|
|
|
if SAVEFIG: |
|
if BLACK: |
|
plt.savefig( |
|
FIGURE_DIR / "umap_drug_embedding_uncertainty_black.png", format="png" |
|
) |
|
else: |
|
plt.savefig( |
|
FIGURE_DIR / "umap_drug_embedding_uncertainty.png", |
|
format="png", |
|
bbox_inches="tight", |
|
) |
|
|
|
|
|
|