|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import scanpy as sc |
|
import seaborn as sns |
|
import torch |
|
import umap.plot |
|
from utils import ( |
|
compute_drug_embeddings, |
|
compute_pred, |
|
load_config, |
|
load_dataset, |
|
load_model, |
|
load_smiles, |
|
) |
|
|
|
from chemCPA.paths import FIGURE_DIR, ROOT |
|
|
|
|
|
# Global switches for this script:
#   BLACK   - use the dark matplotlib theme (and dark-friendly colours in plots)
#   SAVEFIG - write the final figure to FIGURE_DIR
BLACK = False
SAVEFIG = True

if BLACK:
    plt.style.use("dark_background")
else:
    # Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*"
    # and deprecated the old names; pick whichever this installation provides
    # so the script works on both old and new Matplotlib releases.
    talk_style = (
        "seaborn-v0_8-talk"
        if "seaborn-v0_8-talk" in matplotlib.style.available
        else "seaborn-talk"
    )
    matplotlib.style.use("fivethirtyeight")
    matplotlib.style.use(talk_style)
    # "fivethirtyeight" sets a grey figure background; force white on save.
    matplotlib.rcParams["savefig.facecolor"] = "white"
    sns.set_style("whitegrid")

# Settings shared by both themes.
matplotlib.rcParams["font.family"] = "monospace"
matplotlib.rcParams["figure.dpi"] = 300
sns.set_context("poster")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_umap(
    mapper,
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
    groups=None,
    ood_drugs=None,
    ax=None,
):
    """Scatter-plot a 2D drug embedding, coloured by pathway.

    Points whose pathway is in ``groups`` are coloured by pathway; all other
    points are drawn in translucent grey. Drugs listed in ``ood_drugs`` get a
    text label next to their point.

    Args:
        mapper: 2D array indexed as ``mapper[row, 0]`` / ``mapper[row, 1]``;
            row ``i`` is plotted for ``canon_smiles_unique_sorted[i]``.
        canon_smiles_unique_sorted: iterable of SMILES strings, one per row
            of ``mapper``.
        smiles_to_pathway_map: mapping SMILES -> pathway name. SMILES missing
            from this mapping are treated as pathway "other" / drug "unknown".
        smiles_to_drug_map: mapping SMILES -> drug name; must contain every
            SMILES that is present in ``smiles_to_pathway_map``.
        groups: pathway names to highlight. Defaults to epigenetic
            regulation, tyrosine kinase signaling and cell cycle regulation.
        ood_drugs: drug names to annotate with a text label. Defaults to none.
        ax: matplotlib axes to draw on; a new 12x8 figure is created when
            omitted.
    """
    # Defaults are materialised here rather than in the signature so that no
    # mutable object is shared between calls.
    if groups is None:
        groups = [
            "Epigenetic regulation",
            "Tyrosine kinase signaling",
            "Cell cycle regulation",
        ]
    if ood_drugs is None:
        ood_drugs = []
    # Identity check: `ax == None` would be evaluated element-wise (and then
    # fail in the `if`) if an array of axes were passed by mistake.
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    # Per-row pathway and drug name, aligned with the rows of `mapper`.
    pathway = []
    drugs = []
    for smiles in canon_smiles_unique_sorted:
        if smiles in smiles_to_pathway_map:
            pathway.append(smiles_to_pathway_map[smiles])
            drugs.append(smiles_to_drug_map[smiles])
        else:
            pathway.append("other")
            drugs.append("unknown")

    # Collapse every pathway that is not highlighted into "other".
    pathway = np.where(pd.Series(pathway).isin(groups), pathway, "other")

    highlighted = pathway != "other"
    sns.scatterplot(
        x=mapper[highlighted, 0],
        y=mapper[highlighted, 1],
        hue=pathway[highlighted],
        ax=ax,
        palette="tab10" if BLACK else None,
    )
    sns.scatterplot(
        x=mapper[~highlighted, 0],
        y=mapper[~highlighted, 1],
        ax=ax,
        color="grey",
        alpha=0.3,
    )

    # Label the requested drugs, slightly offset from their point.
    shift = 0.05
    for i, label in enumerate(drugs):
        if label in ood_drugs:
            ax.text(
                x=mapper[i, 0] + shift,
                y=mapper[i, 1] + shift,
                s=label,
                fontdict=dict(
                    color="white" if BLACK else "black", alpha=1, size=12, weight=600
                ),
                bbox=dict(facecolor="black" if BLACK else "lightgrey", alpha=0.3),
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Experiment collection and per-model hashes; the collection name and the
# selected hash are what `load_config` is called with further down.
seml_collection = "multi_task"

# One (pretrained, scratch) hash pair per molecule featurizer.
model_hash_pretrained_rdkit, model_hash_scratch_rdkit = (
    "c824e42f7ce751cf9a8ed26f0d9e0af7",
    "59bdaefb1c1adfaf2976e3fdf62afa21",
)
model_hash_pretrained_grover, model_hash_scratch_grover = (
    "c30016a7469feb78a8ee9ebb18ed9b1f",
    "60e4b40e8d67bff2d5efc5e22e265820",
)
model_hash_pretrained_jtvae, model_hash_scratch_jtvae = (
    "915345a522c29fa709b995d6149083b9",
    "934c89b742a6309ad6bb2e1cf90c5e50",
)

# Model used for the rest of this script.
model_hash_pretrained = model_hash_pretrained_rdkit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fetch the stored run configuration for the selected model.
config = load_config(seml_collection, model_hash_pretrained)

# Prefix the configured dataset path and embedding directory with ROOT.
data_params = config["dataset"]["data_params"]
data_params["dataset_path"] = ROOT / data_params["dataset_path"]
embedding_params = config["model"]["embedding"]
embedding_params["directory"] = ROOT / embedding_params["directory"]

# Load the dataset and record its number of variables in the config.
dataset, key_dict = load_dataset(config)
config["dataset"]["n_vars"] = dataset.n_vars

# SMILES list plus SMILES -> pathway / SMILES -> drug lookups.
(
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
) = load_smiles(config, dataset, key_dict, True)

model_pretrained, embedding_pretrained = load_model(
    config, canon_smiles_unique_sorted
)

# Notebook-style echo of the loaded objects (no effect when run as a script).
dataset

config
|
|
|
|
|
|
|
|
|
|
|
# Unique `condition` values of the observations whose multi-task split is "ood".
ood_mask = dataset.obs.split_ood_multi_task.isin(["ood"])
ood_drugs = dataset.obs.condition[ood_mask].unique().to_list()

# Notebook-style echo (no effect when run as a script).
ood_drugs
|
|
|
|
|
|
|
|
|
|
|
# SMILES -> level-2 pathway, and level-1 pathway -> set of its level-2 pathways.
smiles_to_pw_level2_map = {}
pw1_to_pw2 = {}

# observed=True restricts the groups to (SMILES, pw1, pw2) combinations that
# actually occur in the data. With categorical columns the legacy default can
# enumerate unobserved combinations, and it emits a FutureWarning on recent
# pandas. Only the group keys are needed, not the group frames.
for (drug, pw1, pw2), _ in dataset.obs.groupby(
    ["SMILES", "pathway_level_1", "pathway_level_2"], observed=True
):
    smiles_to_pw_level2_map[drug] = pw2
    pw1_to_pw2.setdefault(pw1, set()).add(pw2)

# Relabel drugs whose level-2 pathway is histone deacetylation so they can be
# highlighted as their own group in the plot.
for s, pw in smiles_to_pw_level2_map.items():
    if pw == "Histone deacetylation":
        smiles_to_pathway_map[s] = pw

# Pathway groups highlighted in the UMAP figure.
groups = [
    "Histone deacetylation",
    "Tyrosine kinase signaling",
    "Cell cycle regulation",
    "DNA damage & DNA repair",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Drug embeddings at the high dosage (1e4), projected to 2D with UMAP.
transf_embeddings_pretrained_high = compute_drug_embeddings(
    model_pretrained, embedding_pretrained, dosage=1e4
)
umap_high = umap.UMAP(n_neighbors=25, min_dist=1, spread=2, metric="euclidean")
mapper_pretrained_high = umap_high.fit_transform(transf_embeddings_pretrained_high)

# Same projection settings for the low dosage (10); the metric is left at the
# UMAP default here.
transf_embeddings_pretrained_low = compute_drug_embeddings(
    model_pretrained, embedding_pretrained, dosage=10
)
umap_low = umap.UMAP(n_neighbors=25, min_dist=1, spread=2)
mapper_pretrained_low = umap_low.fit_transform(transf_embeddings_pretrained_low)
|
|
|
|
|
|
|
|
|
|
|
# Two panels side by side: low dosage on the left, high dosage on the right.
fig, ax = plt.subplots(1, 2, figsize=(21, 5))

plot_umap(
    mapper_pretrained_high,
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
    groups=groups,
    ood_drugs=ood_drugs,
    ax=ax[1],
)

plot_umap(
    mapper_pretrained_low,
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
    groups=groups,
    ood_drugs=ood_drugs,
    ax=ax[0],
)

# Raw strings: "\," and "\m" are mathtext commands, not Python escapes. The
# previous placeholder-less f-strings relied on invalid escape sequences,
# which newer Python versions warn about.
ax[0].set(xticklabels=[], yticklabels=[])
ax[0].set_xlabel(r"UMAP of $z_d$ for a dosage of $10\,$nM")

ax[1].set(xticklabels=[], yticklabels=[])
ax[1].set_xlabel(r"UMAP of $z_d$ for a dosage of $10\,\mu$M")

ax[0].grid(False)
ax[1].grid(False)

# Keep a single shared legend, placed outside the right-hand panel.
# get_legend() returns None when the axes has no legend, so guard the removal.
left_legend = ax[0].get_legend()
if left_legend is not None:
    left_legend.remove()
ax[1].legend(
    title="Pathway",
    fontsize=18,
    title_fontsize=22,
    loc="upper left",
    bbox_to_anchor=(1, 1),
)
plt.tight_layout()

if SAVEFIG:
    figure_name = (
        "umap_drug_embedding_black.png" if BLACK else "umap_drug_embedding.png"
    )
    plt.savefig(FIGURE_DIR / figure_name, format="png")
|
|
|
|
|
|
|
|
|
|
|
|