import pickle
import sys
import time

import numpy as np
import rdkit
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from dgllife.data import JTVAEZINC, JTVAECollator, JTVAEDataset
from dgllife.model import JTNNVAE
from torch.utils.data import DataLoader

from utils import get_timestamp, mkdir_p


def main(args):
    print(f"{get_timestamp()}: {args}")
    mkdir_p(args.save_path)
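
    # Silence RDKit logging during molecule preprocessing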
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)
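
    # Use the first GPU when available, unless --use-cpu is specified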
    if args.use_cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")
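
    # Load the pre-built junction tree vocabulary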
    with open(args.vocab_path, "rb") as f:
        vocab = pickle.load(f)
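    # Fall back to the built-in ZINC subset when no training file is given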
    if args.train_path is None:
        dataset = JTVAEZINC("train", vocab)
    else:
        dataset = JTVAEDataset(args.train_path, vocab, training=True)
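    # Batch examples with the JTVAE-specific collate function; drop the last
    # incomplete batch so every batch has exactly batch_size molecules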
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=JTVAECollator(training=True),
        drop_last=True,
    )
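
    # Initialize the model from a checkpoint if provided, otherwise from scratch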
    model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth)
    if args.model_path is not None:
        print(f"Loading model at {args.model_path}")
        model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    else:
        model.reset_parameters()
    model = model.to(device)
    print(
        "Model #Params: {:d}K".format(
            sum(x.nelement() for x in model.parameters()) // 1000
        )
    )
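
    # Adam optimizer with exponential learning rate decay
    # (learning rate is multiplied by gamma at each scheduler step)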
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, args.gamma)
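
    # Track wall-clock time per logging window to estimate the time per epoch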
    dur = []
    t0 = time.time()
    for epoch in range(args.max_epoch):
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
        for it, (
            batch_trees,
            batch_tree_graphs,
            batch_mol_graphs,
            stereo_cand_batch_idx,
            stereo_cand_labels,
            batch_stereo_cand_graphs,
        ) in enumerate(dataloader):
            batch_tree_graphs = batch_tree_graphs.to(device)
            batch_mol_graphs = batch_mol_graphs.to(device)
            stereo_cand_batch_idx = stereo_cand_batch_idx.to(device)
            batch_stereo_cand_graphs = batch_stereo_cand_graphs.to(device)
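
            # The forward pass returns the total loss, the KL divergence, and
            # the word / topology / assembly / stereo prediction accuracies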
            loss, kl_div, wacc, tacc, sacc, dacc = model(
                batch_trees,
                batch_tree_graphs,
                batch_mol_graphs,
                stereo_cand_batch_idx,
                stereo_cand_labels,
                batch_stereo_cand_graphs,
                beta=args.beta,
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
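
            # Accumulate per-batch accuracies over the current logging window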
            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % args.print_iter == 0:
                dur.append(time.time() - t0)
                word_acc = word_acc / args.print_iter * 100
                topo_acc = topo_acc / args.print_iter * 100
                assm_acc = assm_acc / args.print_iter * 100
                steo_acc = steo_acc / args.print_iter * 100

                print(
                    get_timestamp(),
                    "Epoch {:d}/{:d} | Iter {:d}/{:d} | KL: {:.1f}, Word: {:.2f}, "
                    "Topo: {:.2f}, Assm: {:.2f}, Steo: {:.2f} | "
                    "Estimated time per epoch: {:.4f}s".format(
                        epoch + 1,
                        args.max_epoch,
                        it + 1,
                        len(dataloader),
                        kl_div,
                        word_acc,
                        topo_acc,
                        assm_acc,
                        steo_acc,
                        np.mean(dur) / args.print_iter * len(dataloader),
                    ),
                )
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()
                t0 = time.time()
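
            # Decay the learning rate every 15000 iterations within an epoch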
            if (it + 1) % 15000 == 0:
                scheduler.step()
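
            # Periodically save an intra-epoch checkpoint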
            if (it + 1) % args.save_iter == 0:
                save_path = args.save_path + f"/model.epoch-{epoch}-iter-{it}"
                print(get_timestamp(), f"Saving checkpoint at {save_path}")
                torch.save(model.state_dict(), save_path)
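
        # Decay the learning rate and checkpoint once more at the end of each epoch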
        scheduler.step()
        torch.save(model.state_dict(), args.save_path + f"/model.iter-{epoch}")
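
    # Return the most recent metrics; the accuracies are reset after each
    # logging window, so they cover the iterations since the last print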
    return {
        "KL": kl_div,
        "Word": word_acc,
        "Topo": topo_acc,
        "Assm": assm_acc,
        "Steo": steo_acc,
    }


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "-tr",
        "--train-path",
        type=str,
        help="Path to the training molecules, with one SMILES string per line",
    )
    # Required by main() for loading the vocabulary pickle; the short option
    # name "-vo" is an assumption
    parser.add_argument(
        "-vo",
        "--vocab-path",
        type=str,
        required=True,
        help="Path to the pickled junction tree vocabulary",
    )
    parser.add_argument(
        "-s",
        "--save-path",
        type=str,
        default="vae_model",
        help="Directory to save model checkpoints",
    )
    parser.add_argument(
        "-m", "--model-path", type=str, help="Path to pre-trained model checkpoint"
    )
    parser.add_argument("-b", "--batch-size", type=int, default=40, help="Batch size")
    parser.add_argument(
        "-w", "--hidden-size", type=int, default=450, help="Hidden size"
    )
    parser.add_argument(
        "-l", "--latent-size", type=int, default=56, help="Latent size"
    )
    parser.add_argument(
        "-d", "--depth", type=int, default=3, help="Number of GNN layers"
    )
    parser.add_argument(
        "-z", "--beta", type=float, default=0.001, help="Weight for KL loss term"
    )
    parser.add_argument("-lr", "--lr", type=float, default=0.0007, help="Learning rate")
    parser.add_argument(
        "-g",
        "--gamma",
        type=float,
        default=0.9,
        help="Multiplicative factor for learning rate decay",
    )
    parser.add_argument(
        "-me",
        "--max-epoch",
        type=int,
        default=7,
        help="Maximum number of epochs for training",
    )
    parser.add_argument(
        "-nw",
        "--num-workers",
        type=int,
        default=4,
        help="Number of subprocesses for data loading",
    )
    parser.add_argument(
        "-pi",
        "--print-iter",
        type=int,
        default=20,
        help="Frequency (in iterations) for printing evaluation metrics",
    )
    # Required by main() for intra-epoch checkpointing; the short option name
    # "-si" and the default value are assumptions
    parser.add_argument(
        "-si",
        "--save-iter",
        type=int,
        default=5000,
        help="Frequency (in iterations) for saving model checkpoints",
    )
    parser.add_argument(
        "-cpu",
        "--use-cpu",
        action="store_true",
        help="By default, the script uses GPU whenever available. "
        "This flag enforces the use of CPU.",
    )
    args = parser.parse_args()

    main(args)
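
    # Example invocation (illustrative paths; assumes this script is saved as
    # train.py and the vocabulary was pickled beforehand):
    #   python train.py -vo vocab.pkl -tr train.txt -s vae_model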