import pandas as pd
import torch.testing

from chemCPA.data.data import Dataset
from chemCPA.embedding import get_chemical_representation
from chemCPA.model import ComPert


def test_embedding_idx_roundtrip():
    """Test that the same drug embeddings are computed for all drugs in the
    trapnell_subset, independent of whether we use indices or one-hot encodings."""
    kwargs = {
        "perturbation_key": "condition",
        "pert_category": "cov_drug_dose_name",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smiles_key": "SMILES",
        "split_key": "split",
    }

    # load the precomputed GROVER embedding of DMSO (SMILES "CS(C)=O")
    control_emb = torch.tensor(
        pd.read_parquet("embeddings/grover/data/embeddings/grover_base.parquet")
        .loc["CS(C)=O"]
        .values
    )

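    # run the same check twice: once with integer drug indices, once with one-hot encodings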
    for use_drugs_idx in [True, False]:
        dataset = Dataset(
            fname="datasets/trapnell_cpa_subset.h5ad",
            **kwargs,
            use_drugs_idx=use_drugs_idx
        )
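        # look up the precomputed GROVER embeddings for the drugs in the dataset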
        embedding = get_chemical_representation(
            data_dir="embeddings/",
            smiles=dataset.canon_smiles_unique_sorted,
            embedding_model="grover_base",
        )
        device = embedding.weight.device

        # make sure the "control" perturbation is mapped to the DMSO embedding
        control = torch.tensor(
            list(dataset.drugs_names_unique_sorted).index("control"),
            device=device,
        )
        torch.testing.assert_close(embedding(control), control_emb.to(device))

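        # build a ComPert model that uses the precomputed drug embeddings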
        model = ComPert(
            dataset.num_genes,
            dataset.num_drugs,
            dataset.num_covariates,
            device=device,
            doser_type="sigm",
            drug_embeddings=embedding,
            use_drugs_idx=use_drugs_idx,
        )
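        # the dataset yields different tuples depending on the drug encoding:
        # (genes, drug indices, dosages, covariates) vs. (genes, one-hot drugs, covariates)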
        if use_drugs_idx:
            genes, idx, dosages, covariates = dataset[:]
            idx_emb = model.compute_drug_embeddings_(drugs_idx=idx, dosages=dosages)
        else:
            genes, drugs, covariates = dataset[:]
            ohe_emb = model.compute_drug_embeddings_(drugs=drugs)

    # assert that both paths return the same embeddings for the drugs in the dataset
    torch.testing.assert_close(idx_emb, ohe_emb)