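"""
Split a SMILES CSV and its precomputed features into shuffled, fixed-size
shards for pretraining.

Output layout under --output_path:
    graph/<i>.csv     rows of the input CSV (header repeated in every shard)
    feature/<i>.npz   matching feature rows, stored under the key "features"
    summary.txt       n_files, n_samples and sample_per_file
"""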
import csv
import os
import shutil
from argparse import ArgumentParser

import grover.util.utils as fea_utils
import numpy as np

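parser = ArgumentParser(description="Shard a SMILES CSV and its features for pretraining.")
parser.add_argument(
    "--data_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo.csv",
    help="Input CSV; the first row is treated as the header.",
)
parser.add_argument(
    "--features_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz",
    help="Precomputed features with one row per CSV data row.",
)
parser.add_argument(
    "--sample_per_file", type=int, default=1000, help="Maximum rows per shard."
)
parser.add_argument(
    "--output_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo",
    help="Output directory; it is deleted and recreated if it already exists.",
)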


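def load_smiles(data_path):
    """Return (rows, header) from a CSV file whose first row is the header."""
    # newline="" is what the csv module expects for files it reads or writes.
    with open(data_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        res = list(reader)
    return res, header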


def load_features(data_path):
    fea = fea_utils.load_features(data_path)
    return fea
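

# Illustrative sketch, not part of the original pipeline: `parse_summary` is a
# hypothetical helper for reading back the "key:value" lines that run() writes
# to summary.txt (n_files, n_samples, sample_per_file). Blank lines are skipped
# and every value is assumed to be an integer, as run() writes them with %d.
def parse_summary(text):
    """Parse summary.txt content into a dict of integer fields."""
    summary = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        key, value = line.split(":", 1)
        summary[key.strip()] = int(value)
    return summary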


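def save_smiles(data_path, index, data, header):
    """Write one shard of CSV rows, with the header repeated, to <data_path>/<index>.csv."""
    fn = os.path.join(data_path, str(index) + ".csv")
    # newline="" stops the csv module from emitting blank rows on Windows.
    with open(fn, "w", newline="") as f:
        fw = csv.writer(f)
        fw.writerow(header)
        fw.writerows(data)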


def save_features(data_path, index, data):
    fn = os.path.join(data_path, str(index) + ".npz")
    np.savez_compressed(fn, features=data)
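

# Illustrative sketch, not called by run(): `shard_bounds` is a hypothetical
# helper that spells out the slicing used in run(). Shard i covers positions
# [i * sample_per_file, min((i + 1) * sample_per_file, n_graphs)) of the
# shuffled permutation. This sketch uses ceiling division for the shard count,
# so it never yields an empty trailing shard.
# Example: shard_bounds(2500, 1000) -> [(0, 1000), (1000, 2000), (2000, 2500)]
def shard_bounds(n_graphs, sample_per_file):
    """Return the (start, end) slice bounds of every shard."""
    n_files = (n_graphs + sample_per_file - 1) // sample_per_file
    return [
        (i * sample_per_file, min((i + 1) * sample_per_file, n_graphs))
        for i in range(n_files)
    ]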


def run():
    args = parser.parse_args()
    res, header = load_smiles(data_path=args.data_path)
    fea = load_features(data_path=args.features_path)
    assert len(res) == fea.shape[0], (
        "row count mismatch: %d SMILES rows vs %d feature rows"
        % (len(res), fea.shape[0])
    )

    n_graphs = len(res)
    perm = np.random.permutation(n_graphs)

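    # Ceiling division, so no empty trailing shard is written when n_graphs
    # is an exact multiple of sample_per_file.
    nfold = (n_graphs + args.sample_per_file - 1) // args.sample_per_file
    print("Number of files: %d" % nfold)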
    if os.path.exists(args.output_path):
        shutil.rmtree(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)
    graph_path = os.path.join(args.output_path, "graph")
    fea_path = os.path.join(args.output_path, "feature")
    os.makedirs(graph_path, exist_ok=True)
    os.makedirs(fea_path, exist_ok=True)

    for i in range(nfold):
        sidx = i * args.sample_per_file
        eidx = min((i + 1) * args.sample_per_file, n_graphs)
        indexes = perm[sidx:eidx]
        sres = [res[j] for j in indexes]
        sfea = fea[indexes]
        save_smiles(graph_path, i, sres, header)
        save_features(fea_path, i, sfea)

    summary_path = os.path.join(args.output_path, "summary.txt")
    with open(summary_path, "w") as summary_fout:
        summary_fout.write("n_files:%d\n" % nfold)
        summary_fout.write("n_samples:%d\n" % n_graphs)
        summary_fout.write("sample_per_file:%d\n" % args.sample_per_file)


if __name__ == "__main__":
    run()