File size: 1,075 Bytes
a48f0ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy
import torch.testing

from chemCPA.data.data import Dataset


def test_dataset_idx_ohe():
    """Check that the index-based and one-hot drug encodings agree.

    The same ``.h5ad`` file is loaded twice with identical keys, once with
    ``use_drugs_idx=True`` and once with ``use_drugs_idx=False``. The test
    then verifies that:

    * the drug categories of the one-hot encoder match the sorted unique
      drug names of the index-based dataset,
    * both datasets contain the same number of samples,
    * for every sample, genes and covariates are equal, and the one-hot drug
      vector holds the dosage at the position given by the drug index.
    """
    fname = "datasets/trapnell_cpa_subset.h5ad"
    kwargs = {
        "perturbation_key": "condition",
        "pert_category": "cov_drug_dose_name",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smiles_key": "SMILES",
        "split_key": "split",
    }
    d_idx = Dataset(fname=fname, **kwargs, use_drugs_idx=True)
    d_ohe = Dataset(fname=fname, **kwargs, use_drugs_idx=False)

    # The position of a drug in the one-hot vector is only comparable to the
    # integer index if both datasets order the drugs identically.
    numpy.testing.assert_equal(
        d_ohe.encoder_drug.categories_[0], d_idx.drugs_names_unique_sorted
    )

    # Without this check a longer d_ohe would leave samples unverified and a
    # shorter one would fail with an opaque indexing error inside the loop.
    assert len(d_idx) == len(d_ohe), (
        f"dataset lengths differ: idx={len(d_idx)}, ohe={len(d_ohe)}"
    )

    # Index explicitly: both datasets are accessed by position, and only
    # __len__/__getitem__ are relied upon here.
    for i in range(len(d_idx)):
        genes_idx, idx, dosage, cov_idx = d_idx[i]
        genes_ohe, drug, cov_ohe = d_ohe[i]
        torch.testing.assert_close(genes_idx, genes_ohe)
        # make sure the OHE and the index representation encode the same info
        # NOTE(review): this only checks the entry at `idx`; presumably all
        # other entries of `drug` are zero -- confirm against Dataset.
        torch.testing.assert_close(drug[idx], dosage)
        torch.testing.assert_close(cov_idx, cov_ohe)