|
import numpy |
|
import torch.testing |
|
|
|
from chemCPA.data.data import Dataset |
|
|
|
|
|
def test_dataset_idx_ohe():
    """Check that the index-based and one-hot drug encodings agree.

    The same ``.h5ad`` file is loaded twice through ``Dataset``: once with
    ``use_drugs_idx=True`` and once with ``use_drugs_idx=False``. The test
    then verifies that:

    * the drug categories of the one-hot encoder match the sorted unique
      drug names of the index-based dataset,
    * both datasets expose the same number of samples,
    * for every sample, the gene tensors and the covariate tensors are
      close, and the entry of the one-hot ``drug`` vector selected by
      ``idx`` is close to ``dosage``.
    """
    # Both datasets must be built from the same file so that samples line up.
    fname = "datasets/trapnell_cpa_subset.h5ad"
    kwargs = {
        "perturbation_key": "condition",
        "pert_category": "cov_drug_dose_name",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smiles_key": "SMILES",
        "split_key": "split",
    }
    d_idx = Dataset(fname=fname, **kwargs, use_drugs_idx=True)
    d_ohe = Dataset(fname=fname, **kwargs, use_drugs_idx=False)

    # The position of a drug in the encoder's categories has to match its
    # position in the sorted name list, otherwise `drug[idx]` below would
    # compare entries belonging to different drugs.
    numpy.testing.assert_equal(
        d_ohe.encoder_drug.categories_[0], d_idx.drugs_names_unique_sorted
    )

    # Guard the per-sample loop: without this check a longer `d_ohe` would
    # leave trailing samples unverified, and a shorter one would fail with
    # an uninformative IndexError.
    assert len(d_idx) == len(d_ohe), (
        f"dataset sizes differ: {len(d_idx)} (idx) vs {len(d_ohe)} (ohe)"
    )

    for i in range(len(d_idx)):
        # Index mode yields four items per sample, one-hot mode yields three.
        genes_idx, idx, dosage, cov_idx = d_idx[i]
        genes_ohe, drug, cov_ohe = d_ohe[i]

        torch.testing.assert_close(genes_idx, genes_ohe)
        # The one-hot vector entry selected by the drug index must equal
        # the dosage reported in index mode.
        torch.testing.assert_close(drug[idx], dosage)
        torch.testing.assert_close(cov_idx, cov_ohe)
|
|