|
import pandas as pd |
|
import torch.testing |
|
|
|
from chemCPA.data.data import Dataset |
|
from chemCPA.embedding import get_chemical_representation |
|
from chemCPA.model import ComPert |
|
|
|
|
|
def test_embedding_idx_roundtrip():
    """Check that index-based and one-hot drug encodings yield the same embeddings.

    For each of ``use_drugs_idx=True`` and ``use_drugs_idx=False`` the test:

    1. loads ``datasets/trapnell_cpa_subset.h5ad`` and the ``grover_base``
       chemical representation for its unique canonical SMILES,
    2. asserts that the embedding row of the ``"control"`` drug equals the
       stored ``grover_base`` vector for SMILES ``CS(C)=O``,
    3. builds a ``ComPert`` model and computes drug embeddings for the
       whole dataset.

    Finally, the embeddings from the index path and the one-hot path
    must be close.
    """
    kwargs = {
        "perturbation_key": "condition",
        "pert_category": "cov_drug_dose_name",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smiles_key": "SMILES",
        "split_key": "split",
    }

    # Reference vector for the control compound (SMILES "CS(C)=O"),
    # read directly from the stored GROVER embedding table.
    control_emb = torch.tensor(
        pd.read_parquet("embeddings/grover/data/embeddings/grover_base.parquet")
        .loc["CS(C)=O"]
        .values
    )

    # Drug embeddings produced by each encoding path, keyed by use_drugs_idx.
    # Collecting them here avoids names that are bound in only one branch.
    drug_embs = {}
    for use_drugs_idx in (True, False):
        dataset = Dataset(
            fname="datasets/trapnell_cpa_subset.h5ad",
            **kwargs,
            use_drugs_idx=use_drugs_idx,
        )
        embedding = get_chemical_representation(
            data_dir="embeddings/",
            smiles=dataset.canon_smiles_unique_sorted,
            embedding_model="grover_base",
        )
        device = embedding.weight.device

        # The "control" drug's row in the embedding must match the reference.
        control = torch.tensor(
            list(dataset.drugs_names_unique_sorted).index("control"),
            device=device,
        )
        torch.testing.assert_close(embedding(control), control_emb.to(device))

        # A separate model is built for each encoding and their outputs are
        # compared below. Seeding identically (inside fork_rng so the global
        # RNG state is restored afterwards) keeps that comparison from
        # depending on RNG state.
        # NOTE(review): assumes ComPert consumes random numbers in the same
        # order for both encodings — confirm if this ever fails.
        with torch.random.fork_rng():
            torch.manual_seed(0)
            model = ComPert(
                dataset.num_genes,
                dataset.num_drugs,
                dataset.num_covariates,
                device=device,
                doser_type="sigm",
                drug_embeddings=embedding,
                use_drugs_idx=use_drugs_idx,
            )

            if use_drugs_idx:
                _genes, idx, dosages, _covariates = dataset[:]
                drug_embs[use_drugs_idx] = model.compute_drug_embeddings_(
                    drugs_idx=idx, dosages=dosages
                )
            else:
                _genes, drugs, _covariates = dataset[:]
                drug_embs[use_drugs_idx] = model.compute_drug_embeddings_(
                    drugs=drugs
                )

    # Index-based and one-hot encodings must produce the same drug embeddings.
    torch.testing.assert_close(drug_embs[True], drug_embs[False])
|
|