File size: 3,019 Bytes
3ea26d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import pytest

from chemprop.data import ClassBalanceSampler, MoleculeDatapoint, MoleculeDataset, SeededSampler
from chemprop.featurizers.molgraph import SimpleMoleculeMolGraphFeaturizer


@pytest.fixture(params=[0.0, 0.1, 0.5, 1.0])
def threshold(request):
    """Supply one cutoff per parametrized run.

    The value is consumed by the ``bin_targets`` fixture, which compares the
    targets against it.
    """
    cutoff = request.param
    return cutoff


@pytest.fixture
def bin_targets(targets, threshold):
    """Binarize ``targets``: an entry is truthy where the target is <= ``threshold``."""
    at_or_below_cutoff = targets <= threshold
    return at_or_below_cutoff


@pytest.fixture
def featurizer():
    """Provide a default-constructed ``SimpleMoleculeMolGraphFeaturizer``."""
    molgraph_featurizer = SimpleMoleculeMolGraphFeaturizer()
    return molgraph_featurizer


@pytest.fixture
def dataset(mols, targets, featurizer):
    """Build a ``MoleculeDataset`` that pairs each molecule with its target.

    Molecules and targets are matched positionally via ``zip``.
    """
    datapoints = []
    for mol, y in zip(mols, targets):
        datapoints.append(MoleculeDatapoint(mol, y))

    return MoleculeDataset(datapoints, featurizer)


@pytest.fixture(params=[0, 24, 100])
def seed(request):
    """Supply one seed value per parametrized run for the seeded-sampler tests."""
    seed_value = request.param
    return seed_value


@pytest.fixture
def class_sampler(mols, bin_targets, featurizer):
    """Create a shuffling ``ClassBalanceSampler`` over the binarized targets.

    A dataset is assembled from the molecules and their binarized targets, and
    the sampler is built from that dataset's ``Y`` attribute.
    """
    datapoints = []
    for mol, label in zip(mols, bin_targets):
        datapoints.append(MoleculeDatapoint(mol, label))

    labeled_dataset = MoleculeDataset(datapoints, featurizer)
    labels = labeled_dataset.Y

    return ClassBalanceSampler(labels, shuffle=True)


def test_seeded_no_seed(dataset):
    """Constructing a ``SeededSampler`` with ``None`` as the seed must raise ``ValueError``."""
    # Size the dataset outside the ``raises`` block so that only the sampler
    # constructor can satisfy the expected-exception check; a ValueError coming
    # from ``len(dataset)`` should surface as a failure, not a pass.
    n_samples = len(dataset)

    with pytest.raises(ValueError):
        SeededSampler(n_samples, None)


def test_seeded_shuffle(dataset, seed):
    """Two consecutive passes over one ``SeededSampler`` must give different index orders."""
    sampler = SeededSampler(len(dataset), seed)

    first_pass = list(sampler)
    second_pass = list(sampler)

    assert first_pass != second_pass


def test_seeded_fixed_shuffle(dataset, seed):
    """Two ``SeededSampler`` objects built with the same seed must give the same first pass."""
    # Build both samplers before iterating either one.
    samplers = [SeededSampler(len(dataset), seed) for _ in range(2)]

    first_order, second_order = (list(sampler) for sampler in samplers)

    assert first_order == second_order


def test_class_balance_length(class_sampler, bin_targets: np.ndarray):
    n_actives = bin_targets.any(1).sum(0)
    n_inactives = len(bin_targets) - n_actives
    expected_length = 2 * min(n_actives, n_inactives)

    assert len(class_sampler) == expected_length


def test_class_balance_sample(class_sampler, bin_targets: np.ndarray):
    idxs = list(class_sampler)

    # sampled indices should be 50/50 actives/inacitves
    assert sum(bin_targets[idxs]) == len(idxs) // 2


def test_class_balance_shuffle(class_sampler):
    """Two consecutive passes over the shuffling sampler must give different index orders.

    The test is skipped when the sampler reports a length of zero, since there
    is nothing to shuffle in that case.
    """
    # Guard before sampling (as test_seed_class_balance_shuffle does) so an
    # empty sampler is skipped without first being iterated twice.
    if len(class_sampler) == 0:
        pytest.skip("no indices to sample!")

    idxs1 = list(class_sampler)
    idxs2 = list(class_sampler)

    assert idxs1 != idxs2


def test_seed_class_balance_shuffle(smis, bin_targets, featurizer, seed):
    """A seeded ``ClassBalanceSampler`` must give different orders on consecutive passes.

    The test is skipped when the sampler reports a length of zero.
    """
    datapoints = []
    for smi, label in zip(smis, bin_targets):
        datapoints.append(MoleculeDatapoint.from_smi(smi, label))
    dset = MoleculeDataset(datapoints, featurizer)

    balanced_sampler = ClassBalanceSampler(dset.Y, seed, True)

    if len(balanced_sampler) == 0:
        pytest.skip("no indices to sample!")

    first_pass = list(balanced_sampler)
    second_pass = list(balanced_sampler)

    assert first_pass != second_pass


def test_seed_class_balance_reproducibility(smis, bin_targets, featurizer, seed):
    """Two ``ClassBalanceSampler`` objects built with the same seed must give the same first pass."""
    datapoints = []
    for smi, label in zip(smis, bin_targets):
        datapoints.append(MoleculeDatapoint.from_smi(smi, label))
    dset = MoleculeDataset(datapoints, featurizer)

    # Build both samplers before iterating either one.
    samplers = [ClassBalanceSampler(dset.Y, seed, True) for _ in range(2)]

    first_order, second_order = (list(sampler) for sampler in samplers)

    assert first_order == second_order