File size: 4,265 Bytes
a48f0ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# ---
# jupyter:
#   jupytext:
#     notebook_metadata_filter: -kernelspec
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.14.1
# ---

# %% [markdown]
# # JTVAE embedding
# This notebook computes molecule embeddings with the Junction Tree VAE (JT-VAE), as implemented in DGL-LifeSci.
#
# The model was trained on LINCS + Trapnell + half of ZINC (~220K molecules total).
# LINCS contains a `Cl.[Li]` molecule that fails during encoding, so it is assigned a dummy embedding (the mean of all successfully computed embeddings).

# %%
import pickle

import pandas as pd
import rdkit
import torch
from dgllife.data import JTVAECollator, JTVAEDataset
from dgllife.model import load_pretrained
from tqdm import tqdm

print(rdkit.__version__)
print(torch.__version__)
# Everything below runs on "cuda". Fail early with a clear message rather than
# with an `assert`, which is stripped when Python runs with `-O`.
if not torch.cuda.is_available():
    raise RuntimeError(
        "This notebook requires a CUDA-capable GPU, "
        "but torch.cuda.is_available() returned False."
    )

# %% pycharm={"name": "#%%\n"}
from dgllife.model import JTNNVAE

# Either load DGL-LifeSci's published JTVAE checkpoint, or (the default here)
# rebuild the architecture from a locally trained vocabulary + state dict.
from_pretrained = False
if from_pretrained:
    model = load_pretrained("JTVAE_ZINC_no_kl")
else:
    # Presumably the training set the local checkpoint was fit on; kept for
    # provenance only -- it is not read anywhere in this notebook.
    trainfile = "data/train_077a9bedefe77f2a34187eb57be2d416.txt"
    modelfile = "data/model-vaetrain-final.pt"
    vocabfile = "data/vocab-final.pkl"

    # NOTE: unpickling can execute arbitrary code; only load vocab files you trust.
    with open(vocabfile, "rb") as f:
        vocab = pickle.load(f)

    # Hyperparameters presumably match those used to train the checkpoint;
    # load_state_dict fails on any parameter-shape mismatch.
    model = JTNNVAE(vocab=vocab, hidden_size=450, latent_size=56, depth=3)
    # Map to CPU first so the checkpoint loads regardless of the device it was saved from.
    model.load_state_dict(torch.load(modelfile, map_location="cpu"))


# %% pycharm={"name": "#%%\n"}
model = model.to("cuda")
# The notebook only runs inference, so disable training-mode behaviour
# (a no-op unless the model contains dropout / batch-norm style layers).
model.eval()

# %% pycharm={"name": "#%%\n"}
# JTVAEDataset must not see the CSV header row, so write the SMILES back out
# without the header (and without the pandas index).
smiles_df = pd.read_csv("../lincs_trapnell.smiles")
smiles_df.to_csv("jtvae_dataset.smiles", index=False, header=False)

# %% pycharm={"name": "#%%\n"}
# Inference-mode dataset/collator over the header-less SMILES file, iterated
# one molecule at a time in file order.
collator = JTVAECollator(training=False)
dataset = JTVAEDataset("jtvae_dataset.smiles", vocab=model.vocab, training=False)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
    drop_last=True,
)

# %% [markdown]
# ## Reconstruction demo
# Reconstruct a few molecules to spot-check reconstruction quality (it is poor).

# %% pycharm={"name": "#%%\n"}
# Decode the first few molecules and count exact SMILES matches. Only
# inference happens here, so autograd is disabled to avoid building graphs.
n_demo = 11
acc = 0.0
tot = 0
device = "cuda"
with torch.no_grad():
    for it, (tree, tree_graph, mol_graph) in enumerate(dataloader):
        if it >= n_demo:
            break
        tot = it + 1
        true_smiles = tree.smiles
        tree_graph = tree_graph.to(device)
        mol_graph = mol_graph.to(device)
        dec_smiles = model.reconstruct(tree_graph, mol_graph)
        print(dec_smiles)
        print(true_smiles)
        print()
        if dec_smiles == true_smiles:
            acc += 1
# `tot` stays 0 if the dataloader yielded nothing; avoid dividing by zero.
if tot:
    print("Final acc: {:.4f}".format(acc / tot))
else:
    print("No molecules were reconstructed.")

# %% [markdown]
# ## Generate embeddings for all LINCS + Trapnell molecules

# %% pycharm={"is_executing": true, "name": "#%%\n"}
# Encode every molecule into one latent row: the output of the model's
# `T_mean` head (tree half) concatenated with its `G_mean` head (graph half).
# These are presumably the posterior-mean projections, which makes the
# embedding deterministic. Failures are recorded in `errors` instead of
# aborting the run; the expected failure is the `Cl.[Li]` salt.
errors = []
smiles = []
latents = []
for i in tqdm(range(len(dataset))):
    mol_smiles = dataset.data[i]
    try:
        _, batch_tree_graphs, batch_mol_graphs = collator([dataset[i]])
        batch_tree_graphs = batch_tree_graphs.to("cuda")
        batch_mol_graphs = batch_mol_graphs.to("cuda")
        # The projection heads must also run under no_grad; otherwise every
        # stored latent keeps an autograd graph alive on the GPU.
        with torch.no_grad():
            _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs)
            latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1)
    except Exception as e:  # deliberate best-effort: record the failure and continue
        errors.append((mol_smiles, e))
    else:
        # Keep results on the CPU so GPU memory does not grow with the dataset.
        latents.append(latent.cpu())
        smiles.append(mol_smiles)

# %% pycharm={"is_executing": true, "name": "#%%\n"}
# There should only be one error, a Cl.[Li] molecule. Anything else deserves
# investigation before it is papered over with dummy embeddings.
if len(errors) != 1:
    print(f"WARNING: expected exactly 1 encoding failure, got {len(errors)}")
errors

# %% pycharm={"is_executing": true, "name": "#%%\n"}
# Give every molecule that failed to encode a dummy embedding, namely the mean of
# all successfully computed latents, so the output covers every input SMILES.
dummy_emb = torch.cat(latents).mean(dim=0, keepdim=True)
assert dummy_emb.shape == latents[0].shape
for failed_smiles, _ in errors:
    smiles.append(failed_smiles)
    latents.append(dummy_emb)
assert len(latents) == len(smiles)

# %% pycharm={"is_executing": true, "name": "#%%\n"}
# Build one row per molecule (indexed by SMILES) and one column per latent
# dimension, named latent_1 ... latent_d, then persist it as parquet.
np_latents = [lat.squeeze().detach().cpu().numpy() for lat in latents]
n_dims = np_latents[0].shape[0]
final_df = pd.DataFrame(
    np_latents,
    index=smiles,
    columns=["latent_" + str(d) for d in range(1, n_dims + 1)],
)
final_df.to_parquet("data/jtvae_dgl.parquet")

# %% pycharm={"is_executing": true, "name": "#%%\n"}
final_df

# %% pycharm={"is_executing": true, "name": "#%%\n"}
# Sanity check: the embedded molecules should be exactly the input molecules.
# Separate names are used so the `smiles` list built above is not clobbered.
# A set comparison ignores ordering and duplicates.
reference_smiles = pd.read_csv("../lincs_trapnell.smiles")["smiles"]
embedded_smiles = final_df.index

# %% pycharm={"is_executing": true, "name": "#%%\n"}
set(reference_smiles) == set(embedded_smiles)