import numpy as np
import pandas as pd
import scanpy as sc
import seml
import torch
from tqdm.auto import tqdm

from chemCPA.data import (
    SubDataset,
    canonicalize_smiles,
    drug_names_to_once_canon_smiles,
)
from chemCPA.embedding import get_chemical_representation
from chemCPA.model import ComPert
from chemCPA.paths import CHECKPOINT_DIR
from chemCPA.train import bool2idx, compute_prediction, compute_r2, repeat_n


def load_config(seml_collection, model_hash):
    """Fetch the full config of a completed seml experiment identified by its config hash."""
    results_df = seml.get_results(
        seml_collection,
        to_data_frame=True,
        fields=["config", "config_hash"],
        states=["COMPLETED"],
        filter_dict={"config_hash": model_hash},
    )
    experiment = results_df.apply(
        lambda exp: {
            "hash": exp["config_hash"],
            "seed": exp["config.seed"],
            "_id": exp["_id"],
        },
        axis=1,
    )
    assert len(experiment) == 1, "The config hash should identify exactly one completed experiment."
    experiment = experiment.iloc[0]
    collection = seml.database.get_collection(seml_collection)
    config = collection.find_one({"_id": experiment["_id"]})["config"]
    assert config["dataset"]["data_params"]["use_drugs_idx"]
    assert config["model"]["additional_params"]["doser_type"] == "amortized"
    config["config_hash"] = model_hash
    return config
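

# Hypothetical usage sketch (not part of the original module): the seml
# collection name and config hash below are placeholders, not real identifiers.
def _example_load_config():
    seml_collection = "chemCPA_sweep"  # placeholder collection name
    model_hash = "0123456789abcdef"  # placeholder config hash
    return load_config(seml_collection, model_hash)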


def load_dataset(config):
    """Read the AnnData file referenced in the config and return it together with its key names."""
    perturbation_key = config["dataset"]["data_params"]["perturbation_key"]
    smiles_key = config["dataset"]["data_params"]["smiles_key"]
    dataset = sc.read(config["dataset"]["data_params"]["dataset_path"])
    key_dict = {
        "perturbation_key": perturbation_key,
        "smiles_key": smiles_key,
    }
    return dataset, key_dict


def load_smiles(config, dataset, key_dict, return_pathway_map=False):
    """Collect the unique, canonicalised SMILES strings of all drugs in the dataset."""
    perturbation_key = key_dict["perturbation_key"]
    smiles_key = key_dict["smiles_key"]

    # Collect all unique drug names; combination treatments are encoded as "drugA+drugB".
    drugs_names = np.array(dataset.obs[perturbation_key].values)
    drugs_names_unique = set()
    for d in drugs_names:
        drugs_names_unique.update(d.split("+"))
    drugs_names_unique_sorted = np.array(sorted(drugs_names_unique))
    canon_smiles_unique_sorted = drug_names_to_once_canon_smiles(
        list(drugs_names_unique_sorted), dataset, perturbation_key, smiles_key
    )

    # Map canonical SMILES back to drug names (and optionally to pathway annotations).
    smiles_to_drug_map = {
        canonicalize_smiles(smiles): drug
        for smiles, drug in dataset.obs.groupby(
            [config["dataset"]["data_params"]["smiles_key"], perturbation_key]
        ).groups.keys()
    }
    if return_pathway_map:
        smiles_to_pathway_map = {
            canonicalize_smiles(smiles): pathway
            for smiles, pathway in dataset.obs.groupby(
                [config["dataset"]["data_params"]["smiles_key"], "pathway_level_1"]
            ).groups.keys()
        }
        return canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map
    return canon_smiles_unique_sorted, smiles_to_drug_map
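

# Hypothetical usage sketch: read the AnnData object referenced by a loaded
# config and derive the canonical SMILES list plus the SMILES-to-drug map.
def _example_load_smiles(config):
    dataset, key_dict = load_dataset(config)
    canon_smiles_unique_sorted, smiles_to_drug_map = load_smiles(config, dataset, key_dict)
    return canon_smiles_unique_sorted, smiles_to_drug_map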


def load_model(config, canon_smiles_unique_sorted):
    """Restore a trained ComPert model (and its drug embedding) from the checkpoint directory."""
    model_hash = config["config_hash"]
    model_checkp = CHECKPOINT_DIR / (model_hash + ".pt")

    embedding_model = config["model"]["embedding"]["model"]
    if embedding_model == "vanilla":
        embedding = None
    else:
        embedding = get_chemical_representation(
            smiles=canon_smiles_unique_sorted,
            embedding_model=config["model"]["embedding"]["model"],
            data_path=config["model"]["embedding"]["directory"],
            device="cuda",
        )
    dumped_model = torch.load(model_checkp)
    # Older checkpoints contain fewer entries, depending on whether covariate
    # adversaries and covariate embeddings were saved.
    if len(dumped_model) == 3:
        print("This model does not contain the covariate embeddings or adversaries.")
        state_dict, init_args, history = dumped_model
        COV_EMB_AVAILABLE = False
    elif len(dumped_model) == 4:
        print("This model does not contain the covariate embeddings.")
        state_dict, cov_adv_state_dicts, init_args, history = dumped_model
        COV_EMB_AVAILABLE = False
    elif len(dumped_model) == 5:
        (
            state_dict,
            cov_adv_state_dicts,
            cov_emb_state_dicts,
            init_args,
            history,
        ) = dumped_model
        COV_EMB_AVAILABLE = True
        assert len(cov_emb_state_dicts) == 1
    else:
        raise ValueError(f"Unexpected checkpoint format with {len(dumped_model)} entries.")
    append_layer_width = (
        config["dataset"]["n_vars"]
        if (config["model"]["append_ae_layer"] and config["model"]["load_pretrained"])
        else None
    )

    # The pretrained drug embedding is passed in explicitly, so drop it from the state dict.
    if embedding_model != "vanilla":
        state_dict.pop("drug_embeddings.weight")
    model = ComPert(**init_args, drug_embeddings=embedding, append_layer_width=append_layer_width)
    model = model.eval()
    if COV_EMB_AVAILABLE:
        for embedding_cov, state_dict_cov in zip(model.covariates_embeddings, cov_emb_state_dicts):
            embedding_cov.load_state_dict(state_dict_cov)

    incomp_keys = model.load_state_dict(state_dict, strict=False)
    if embedding_model == "vanilla":
        assert len(incomp_keys.unexpected_keys) == 0 and len(incomp_keys.missing_keys) == 0
    else:
        # The drug embedding weight is the only key that is allowed to be missing.
        torch.testing.assert_allclose(model.drug_embeddings.weight, embedding.weight)
        assert (
            len(incomp_keys.missing_keys) == 1 and "drug_embeddings.weight" in incomp_keys.missing_keys
        ), incomp_keys.missing_keys

    return model, embedding
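

# Hypothetical usage sketch: restore a trained ComPert model from its
# checkpoint. Assumes a CUDA device, since load_model places embeddings on "cuda".
def _example_load_model(config):
    dataset, key_dict = load_dataset(config)
    canon_smiles_unique_sorted, _ = load_smiles(config, dataset, key_dict)
    model, embedding = load_model(config, canon_smiles_unique_sorted)
    return model, embedding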


def compute_drug_embeddings(model, embedding, dosage=1e4):
    """Return the dosage-conditioned latent drug embeddings for all drugs in `embedding`."""
    all_drugs_idx = torch.arange(len(embedding.weight))
    dosages = dosage * torch.ones((len(embedding.weight),))

    with torch.no_grad():
        # Apply the (amortized) doser and the drug embedding encoder.
        transf_embeddings = model.compute_drug_embeddings_(drugs_idx=all_drugs_idx, dosages=dosages)

    return transf_embeddings
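

# Hypothetical usage sketch: compare the latent drug embeddings at two dosages,
# e.g. to inspect how strongly the doser shifts them.
def _example_compute_drug_embeddings(model, embedding):
    emb_low = compute_drug_embeddings(model, embedding, dosage=1e3)
    emb_high = compute_drug_embeddings(model, embedding, dosage=1e4)
    return emb_low, emb_high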


def compute_pred(
    model,
    dataset,
    dosages=[1e4],
    cell_lines=None,
    genes_control=None,
    use_DEGs=True,
    verbose=True,
):
    """Compute counterfactual predictions and per-category R2 scores for `dataset`."""
    pert_categories_index = pd.Index(dataset.pert_categories, dtype="category")

    # One-hot covariate vectors of the three cell lines used in the experiments.
    cl_dict = {
        torch.Tensor([1, 0, 0]): "A549",
        torch.Tensor([0, 1, 0]): "K562",
        torch.Tensor([0, 0, 1]): "MCF7",
    }

    if cell_lines is None:
        cell_lines = ["A549", "K562", "MCF7"]

    print(cell_lines)

    predictions_dict = {}
    drug_r2 = {}
    for cell_drug_dose_comb, category_count in tqdm(zip(*np.unique(dataset.pert_categories, return_counts=True))):
        if dataset.perturbation_key is None:
            break

        # Skip sparsely populated categories and control conditions.
        if category_count <= 5:
            continue

        if "dmso" in cell_drug_dose_comb.lower() or "control" in cell_drug_dose_comb.lower():
            continue

        # Indices of the differentially expressed genes for this category.
        bool_de = dataset.var_names.isin(np.array(dataset.de_genes[cell_drug_dose_comb]))
        idx_de = bool2idx(bool_de)

        if len(idx_de) < 2:
            continue

        bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
        idx_all = bool2idx(bool_category)
        idx = idx_all[0]
        y_true = dataset.genes[idx_all, :].to(device="cuda")

        if genes_control is None:
            n_obs = y_true.size(0)
        else:
            assert isinstance(genes_control, torch.Tensor)
            n_obs = genes_control.size(0)

        emb_covs = [repeat_n(cov[idx], n_obs) for cov in dataset.covariates]

        if dataset.dosages[idx] not in dosages:
            continue

        # Skip categories whose cell line is not in the requested `cell_lines`.
        stop = False
        for tensor, cl in cl_dict.items():
            if (tensor == dataset.covariates[0][idx]).all():
                if cl not in cell_lines:
                    stop = True
        if stop:
            continue

        if dataset.use_drugs_idx:
            emb_drugs = (
                repeat_n(dataset.drugs_idx[idx], n_obs).squeeze(),
                repeat_n(dataset.dosages[idx], n_obs).squeeze(),
            )
        else:
            emb_drugs = repeat_n(dataset.drugs[idx], n_obs)

        # Predict either from the perturbed cells themselves or from control cells.
        if genes_control is None:
            mean_pred, _ = compute_prediction(
                model,
                y_true,
                emb_drugs,
                emb_covs,
            )
        else:
            mean_pred, _ = compute_prediction(
                model,
                genes_control,
                emb_drugs,
                emb_covs,
            )

        y_pred = mean_pred
        _y_pred = mean_pred.mean(0)
        _y_true = y_true.mean(0)
        if use_DEGs:
            r2_m_de = compute_r2(_y_true[idx_de].cuda(), _y_pred[idx_de].cuda())
            if verbose:
                print(f"{cell_drug_dose_comb}: {r2_m_de:.2f}")
            drug_r2[cell_drug_dose_comb] = max(r2_m_de, 0.0)
        else:
            r2_m = compute_r2(_y_true.cuda(), _y_pred.cuda())
            if verbose:
                print(f"{cell_drug_dose_comb}: {r2_m:.2f}")
            drug_r2[cell_drug_dose_comb] = max(r2_m, 0.0)

        predictions_dict[cell_drug_dose_comb] = [
            # Store `None` for the control genes when no control cells were provided.
            genes_control.detach().cpu().numpy() if genes_control is not None else None,
            y_pred.detach().cpu().numpy(),
            y_true.detach().cpu().numpy(),
        ]
    return drug_r2, predictions_dict
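

# Hypothetical usage sketch: score counterfactual predictions on an OOD split
# against control cells. `datasets` is assumed to be a dict of SubDataset
# objects with "test_control" and "ood" keys (not defined in this module).
def _example_compute_pred(model, datasets):
    drug_r2, predictions = compute_pred(
        model,
        datasets["ood"],
        genes_control=datasets["test_control"].genes,
        use_DEGs=True,
        verbose=False,
    )
    return drug_r2, predictions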


def compute_pred_ctrl(
    dataset,
    dosages=[1e4],
    cell_lines=None,
    dataset_ctrl=None,
    use_DEGs=True,
    verbose=True,
):
    """Baseline R2 scores obtained by using unperturbed control expression as the prediction."""
    pert_categories_index = pd.Index(dataset.pert_categories, dtype="category")

    # One-hot covariate vectors of the three cell lines used in the experiments.
    cl_dict = {
        torch.Tensor([1, 0, 0]): "A549",
        torch.Tensor([0, 1, 0]): "K562",
        torch.Tensor([0, 0, 1]): "MCF7",
    }

    if cell_lines is None:
        cell_lines = ["A549", "K562", "MCF7"]

    print(cell_lines)

    predictions_dict = {}
    drug_r2 = {}
    for cell_drug_dose_comb, category_count in tqdm(zip(*np.unique(dataset.pert_categories, return_counts=True))):
        if dataset.perturbation_key is None:
            break

        # Skip sparsely populated categories and control conditions.
        if category_count <= 5:
            continue

        if "dmso" in cell_drug_dose_comb.lower() or "control" in cell_drug_dose_comb.lower():
            continue

        # Indices of the differentially expressed genes for this category.
        bool_de = dataset.var_names.isin(np.array(dataset.de_genes[cell_drug_dose_comb]))
        idx_de = bool2idx(bool_de)

        if len(idx_de) < 2:
            continue

        bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
        idx_all = bool2idx(bool_category)
        idx = idx_all[0]
        y_true = dataset.genes[idx_all, :].to(device="cuda")

        # Use the control cells of the matching cell line as the "prediction".
        cov_name = cell_drug_dose_comb.split("_")[0]
        cond = dataset_ctrl.covariate_names["cell_type"] == cov_name
        genes_control = dataset_ctrl.genes[cond]

        if genes_control is None:
            n_obs = y_true.size(0)
        else:
            assert isinstance(genes_control, torch.Tensor)
            n_obs = genes_control.size(0)

        emb_covs = [repeat_n(cov[idx], n_obs) for cov in dataset.covariates]

        if dataset.dosages[idx] not in dosages:
            continue

        # Skip categories whose cell line is not in the requested `cell_lines`.
        stop = False
        for tensor, cl in cl_dict.items():
            if (tensor == dataset.covariates[0][idx]).all():
                if cl not in cell_lines:
                    stop = True
        if stop:
            continue

        if dataset.use_drugs_idx:
            emb_drugs = (
                repeat_n(dataset.drugs_idx[idx], n_obs).squeeze(),
                repeat_n(dataset.dosages[idx], n_obs).squeeze(),
            )
        else:
            emb_drugs = repeat_n(dataset.drugs[idx], n_obs)

        y_pred = genes_control
        _y_pred = genes_control.mean(0)
        _y_true = y_true.mean(0)
        if use_DEGs:
            r2_m_de = compute_r2(_y_true[idx_de].cuda(), _y_pred[idx_de].cuda())
            if verbose:
                print(f"{cell_drug_dose_comb}: {r2_m_de:.2f}")
            drug_r2[cell_drug_dose_comb] = max(r2_m_de, 0.0)
        else:
            r2_m = compute_r2(_y_true.cuda(), _y_pred.cuda())
            if verbose:
                print(f"{cell_drug_dose_comb}: {r2_m:.2f}")
            drug_r2[cell_drug_dose_comb] = max(r2_m, 0.0)

        predictions_dict[cell_drug_dose_comb] = [
            genes_control.detach().cpu().numpy(),
            y_pred.detach().cpu().numpy(),
            y_true.detach().cpu().numpy(),
        ]
    return drug_r2, predictions_dict
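

# Hypothetical usage sketch: baseline R2 obtained by "predicting" control
# expression for every perturbation; useful as a reference for compute_pred.
# `datasets` is assumed to follow the same layout as in _example_compute_pred.
def _example_compute_pred_ctrl(datasets):
    baseline_r2, _ = compute_pred_ctrl(
        datasets["ood"],
        dataset_ctrl=datasets["test_control"],
        use_DEGs=True,
        verbose=False,
    )
    return baseline_r2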


def evaluate_r2(autoencoder: ComPert, dataset: SubDataset, genes_control: torch.Tensor):
    """
    Measures different quality metrics of a ComPert `autoencoder` when tasked
    to translate some `genes_control` into each of the drug/covariate
    combinations described in `dataset`.

    The considered metrics are the R2 scores of means and variances computed
    over all genes, as well as the R2 scores of means and variances computed
    over the differentially expressed (_de) genes only.
    """
    mean_score, var_score, mean_score_de, var_score_de = [], [], [], []
    n_rows = genes_control.size(0)

    pert_categories_index = pd.Index(dataset.pert_categories, dtype="category")
    for cell_drug_dose_comb, category_count in zip(*np.unique(dataset.pert_categories, return_counts=True)):
        if dataset.perturbation_key is None:
            break

        # Skip sparsely populated categories and control conditions.
        if category_count <= 5:
            continue

        if "dmso" in cell_drug_dose_comb.lower() or "control" in cell_drug_dose_comb.lower():
            continue

        # Indices of the differentially expressed genes for this category.
        bool_de = dataset.var_names.isin(np.array(dataset.de_genes[cell_drug_dose_comb]))
        idx_de = bool2idx(bool_de)

        if len(idx_de) < 2:
            continue

        bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
        idx_all = bool2idx(bool_category)
        idx = idx_all[0]

        emb_covs = [repeat_n(cov[idx], n_rows) for cov in dataset.covariates]
        if dataset.use_drugs_idx:
            emb_drugs = (
                repeat_n(dataset.drugs_idx[idx], n_rows).squeeze(),
                repeat_n(dataset.dosages[idx], n_rows).squeeze(),
            )
        else:
            emb_drugs = repeat_n(dataset.drugs[idx], n_rows)
        mean_pred, var_pred = compute_prediction(
            autoencoder,
            genes_control,
            emb_drugs,
            emb_covs,
        )

        y_true = dataset.genes[idx_all, :].to(device="cuda")

        # Per-gene means and variances of the true and predicted expression.
        yt_m = y_true.mean(dim=0)
        yt_v = y_true.var(dim=0)

        yp_m = mean_pred.mean(dim=0).to(device="cuda")
        yp_v = var_pred.mean(dim=0).to(device="cuda")

        r2_m = compute_r2(yt_m, yp_m)
        r2_v = compute_r2(yt_v, yp_v)
        r2_m_de = compute_r2(yt_m[idx_de], yp_m[idx_de])
        r2_v_de = compute_r2(yt_v[idx_de], yp_v[idx_de])

        if r2_m_de == float("-inf") or r2_v_de == float("-inf"):
            continue

        mean_score.append(r2_m)
        var_score.append(r2_v)
        mean_score_de.append(r2_m_de)
        var_score_de.append(r2_v_de)
    print(f"Number of different r2 computations: {len(mean_score)}")
    if len(mean_score) > 0:
        return [np.mean(s) for s in [mean_score, mean_score_de, var_score, var_score_de]]
    else:
        return []
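

# Hypothetical usage sketch: aggregate R2 evaluation of a restored model on one
# split; the returned list is [mean R2, mean R2 (DE genes), var R2, var R2 (DE genes)].
def _example_evaluate_r2(model, datasets):
    return evaluate_r2(
        model,
        datasets["ood"],
        datasets["test_control"].genes,
    )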