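"""Unit tests for `chemprop.data.datasets.MoleculeDataset`: fixture construction,
target and input normalization, and molgraph caching."""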
from unittest.mock import MagicMock, call

import numpy as np
import pytest
from rdkit import Chem
from sklearn.preprocessing import StandardScaler

from chemprop.data.datasets import MoleculeDatapoint, MoleculeDataset
from chemprop.data.molgraph import MolGraph
from chemprop.featurizers.molgraph import SimpleMoleculeMolGraphFeaturizer

@pytest.fixture(params=[1, 5, 10])  # number of SMILES to sample; values are illustrative
def smis(smis, request):
    # expects a broader `smis` fixture (e.g. a pandas Series of SMILES) from a higher
    # scope such as conftest.py and draws a random subset of the parametrized size
    return smis.sample(request.param).to_list()

@pytest.fixture
def targets(smis):
    return np.random.rand(len(smis), 1)

@pytest.fixture
def mols(smis):
    return [Chem.MolFromSmiles(smi) for smi in smis]

@pytest.fixture
def X_d(mols):
    return [np.random.rand(1) for _ in mols]

@pytest.fixture
def V_fs(mols):
    return [np.random.rand(mol.GetNumAtoms(), 1) for mol in mols]

@pytest.fixture
def E_fs(mols):
    return [np.random.rand(mol.GetNumBonds(), 2) for mol in mols]

@pytest.fixture
def V_ds(mols):
    return [np.random.rand(mol.GetNumAtoms(), 3) for mol in mols]

@pytest.fixture
def data(mols, targets, X_d, V_fs, E_fs, V_ds):
    return [
        MoleculeDatapoint(mol=mol, y=target, x_d=x_d, V_f=V_f, E_f=E_f, V_d=V_d)
        for mol, target, x_d, V_f, E_f, V_d in zip(mols, targets, X_d, V_fs, E_fs, V_ds)
    ]

@pytest.fixture(params=[True, False])  # cache setting assumed to be a boolean toggle
def cache(request):
    return request.param

@pytest.fixture
def dataset(data, cache):
    extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0
    extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0
    dset = MoleculeDataset(
        data,
        SimpleMoleculeMolGraphFeaturizer(
            extra_atom_fdim=extra_atom_fdim, extra_bond_fdim=extra_bond_fdim
        ),
    )
    dset.cache = cache
    return dset

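# Fixture chain: smis -> (mols, targets, X_d, V_fs, E_fs, V_ds) -> data -> dataset.
# Because `dataset` depends on the parametrized `cache` fixture, every test below is
# collected once per cache setting.
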
def test_none():
    with pytest.raises(ValueError):
        MoleculeDataset(None, SimpleMoleculeMolGraphFeaturizer())

def test_empty():
    """TODO"""

def test_len(data, dataset):
    assert len(data) == len(dataset)

def test_smis(dataset, smis):
    assert smis == dataset.smiles

def test_targets(dataset, targets):
    np.testing.assert_array_equal(dataset.Y, targets)

def test_set_targets_too_short(dataset):
    with pytest.raises(ValueError):
        dataset.Y = np.random.rand(len(dataset) // 2, 1)

def test_num_tasks(dataset, targets):
    assert dataset.t == targets.shape[1]

def test_aux_nones(dataset: MoleculeDataset):
    np.testing.assert_array_equal(dataset.X_d, None)
    np.testing.assert_array_equal(dataset.V_fs, None)
    np.testing.assert_array_equal(dataset.E_fs, None)
    np.testing.assert_array_equal(dataset.V_ds, None)
    np.testing.assert_array_equal(dataset.gt_mask, None)
    np.testing.assert_array_equal(dataset.lt_mask, None)
    assert dataset.d_xd == 0
    assert dataset.d_vf == 0
    assert dataset.d_ef == 0
    assert dataset.d_vd == 0

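# The normalization tests below mirror the dataset's scaling against a manually fit
# sklearn StandardScaler and compare both the transformed arrays and the returned
# scaler statistics (mean_, scale_).
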
def test_normalize_targets(dataset):
    dset_scaler = dataset.normalize_targets()
    scaler = StandardScaler()
    scaler.fit(dataset._Y)
    Y = scaler.transform(dataset._Y)

    np.testing.assert_array_equal(dataset.Y, Y)
    np.testing.assert_array_equal(dset_scaler.mean_, scaler.mean_)
    np.testing.assert_array_equal(dset_scaler.scale_, scaler.scale_)

def test_normalize_inputs(dataset):
    dset_scaler = dataset.normalize_inputs("X_d")
    scaler = StandardScaler()
    scaler.fit(dataset._X_d)
    X = scaler.transform(dataset._X_d)

    np.testing.assert_array_equal(dataset.X_d, X)
    np.testing.assert_array_equal(dset_scaler.mean_, scaler.mean_)
    np.testing.assert_array_equal(dset_scaler.scale_, scaler.scale_)

    inputs = ["V_f", "E_f", "V_d"]
    for input_ in inputs:
        dset_scaler = dataset.normalize_inputs(input_)
        scaler = StandardScaler()
        Xs = getattr(dataset, f"_{input_}s")
        X = np.concatenate(Xs, axis=0)
        scaler.fit(X)
        Xs = [scaler.transform(x) for x in Xs]
        for X, dset_X in zip(Xs, getattr(dataset, f"{input_}s")):
            np.testing.assert_array_equal(X, dset_X)
        np.testing.assert_array_equal(dset_scaler.mean_, scaler.mean_)
        np.testing.assert_array_equal(dset_scaler.scale_, scaler.scale_)

def test_cache(dataset: MoleculeDataset, cache):
    """Test that the cache attribute is being set appropriately and that the underlying
    cache is being used correctly to load the molgraphs."""
    mg = MolGraph(None, None, None, None)

    dataset.cache = cache
    assert dataset.cache == cache

    # swap in a mock cache that always returns the sentinel molgraph
    dataset.mg_cache = MagicMock()
    dataset.mg_cache.__getitem__.side_effect = lambda i: mg

    calls = []
    for i in range(len(dataset)):
        assert dataset[i].mg is mg
        calls.append(call(i))

    dataset.mg_cache.__getitem__.assert_has_calls(calls)
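
# Run sketch (file name and path assumed), e.g.:
#   pytest test_datasets.py -v
# pytest discovers the parametrized fixtures automatically, so each test runs once per
# `smis` sample size and cache setting.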