|
""" |
|
The data splitting script for pretraining. |
|
""" |
|
import csv |
|
import os |
|
import shutil |
|
from argparse import ArgumentParser |
|
|
|
import grover.util.utils as fea_utils |
|
import numpy as np |
|
|
|
# Command-line options. The defaults point at the delaneyfreesolvlipo dataset.
parser = ArgumentParser(
    description=(
        "Shuffle a CSV file and its row-aligned feature file and split "
        "them into fixed-size shards."
    )
)
parser.add_argument(
    "--data_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo.csv",
    help="Input CSV file; the first row is treated as the header.",
)
parser.add_argument(
    "--features_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz",
    help="Feature file with exactly one row per CSV data row.",
)
parser.add_argument(
    "--sample_per_file",
    type=int,
    default=1000,
    help="Maximum number of samples written to each shard.",
)
parser.add_argument(
    "--output_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo",
    help=(
        "Output directory; deleted and recreated. Shards are written to "
        "its graph/ and feature/ subdirectories."
    ),
)
|
|
|
|
|
def load_smiles(data_path):
    """Read a CSV file and separate its header row from the data rows.

    Args:
        data_path: Path to the CSV file.

    Returns:
        A ``(rows, header)`` tuple. ``rows`` is a list of every data row,
        each a list of strings. ``header`` is the first row.

    Raises:
        ValueError: if the file is empty (has no header row).
    """
    # The csv module requires newline="" so that quoted fields containing
    # line breaks are read back verbatim instead of being newline-translated.
    with open(data_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            # Avoid leaking a bare StopIteration to the caller.
            raise ValueError("CSV file is empty (no header row): %s" % data_path)
        res = list(reader)
    return res, header
|
|
|
|
|
def load_features(data_path):
    """Load the feature file at ``data_path`` via ``grover.util.utils``.

    The value returned by ``fea_utils.load_features`` is passed through
    unchanged. ``run()`` uses it with ``.shape[0]`` and integer-array
    indexing, so it is presumably a NumPy array with one row per sample
    (confirm against grover.util.utils).
    """
    return fea_utils.load_features(data_path)
|
|
|
|
|
def save_smiles(data_path, index, data, header):
    """Write one CSV shard named ``<index>.csv`` inside ``data_path``.

    Args:
        data_path: Directory to write into; it must already exist.
        index: Shard number, used as the file stem.
        data: Iterable of rows (each a sequence of fields) to write.
        header: Column names written as the first row.
    """
    fn = os.path.join(data_path, str(index) + ".csv")
    # newline="" lets the csv writer control line endings. Without it, text
    # mode on Windows turns "\r\n" into "\r\r\n", producing blank rows.
    with open(fn, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(data)
|
|
|
|
|
def save_features(data_path, index, data):
    """Store one feature shard as ``<index>.npz`` under ``data_path``.

    The array is written with ``np.savez_compressed`` under the single key
    ``features``.
    """
    file_name = str(index) + ".npz"
    target = os.path.join(data_path, file_name)
    np.savez_compressed(target, features=data)
|
|
|
|
|
def run():
    """Shuffle the input dataset and write it out as fixed-size shards.

    Reads rows from ``--data_path`` and the row-aligned features from
    ``--features_path``, then applies a single random permutation to both.
    Shard ``i`` is written as ``<output_path>/graph/<i>.csv`` and
    ``<output_path>/feature/<i>.npz``, with at most ``--sample_per_file``
    samples each. Finally, ``<output_path>/summary.txt`` records the shard
    count, the sample count and the shard size.

    Any existing ``--output_path`` directory is deleted first.

    Raises:
        ValueError: if ``--sample_per_file`` is not positive, or if the CSV
            row count differs from the feature row count.
    """
    args = parser.parse_args()
    per_file = args.sample_per_file
    if per_file <= 0:
        raise ValueError("--sample_per_file must be positive, got %d" % per_file)

    res, header = load_smiles(data_path=args.data_path)
    fea = load_features(data_path=args.features_path)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(res) != fea.shape[0]:
        raise ValueError(
            "Row count mismatch: %d CSV rows but %d feature rows"
            % (len(res), fea.shape[0])
        )

    n_graphs = len(res)
    # One permutation indexes both rows and features, keeping them aligned.
    perm = np.random.permutation(n_graphs)

    # Ceiling division, so an exact multiple of per_file does not produce a
    # trailing empty shard.
    nfold = -(-n_graphs // per_file)
    print("Number of files: %d" % nfold)

    # Remove any previous output so no stale shards survive next to new ones.
    if os.path.exists(args.output_path):
        shutil.rmtree(args.output_path)
    graph_path = os.path.join(args.output_path, "graph")
    fea_path = os.path.join(args.output_path, "feature")
    # makedirs also (re)creates args.output_path as the parent directory.
    os.makedirs(graph_path, exist_ok=True)
    os.makedirs(fea_path, exist_ok=True)

    for i in range(nfold):
        # Slicing clamps at the end of perm, so the last shard may be smaller.
        indexes = perm[i * per_file:(i + 1) * per_file]
        sres = [res[j] for j in indexes]
        sfea = fea[indexes]
        save_smiles(graph_path, i, sres, header)
        save_features(fea_path, i, sfea)

    summary_path = os.path.join(args.output_path, "summary.txt")
    with open(summary_path, "w") as summary_fout:
        summary_fout.write("n_files:%d\n" % nfold)
        summary_fout.write("n_samples:%d\n" % n_graphs)
        summary_fout.write("sample_per_file:%d\n" % per_file)
|
|
|
|
|
if __name__ == "__main__": |
|
run() |
|
|