github-actions[bot]
HF snapshot
a48f0ae
"""
The data splitting script for pretraining.
"""
import csv
import os
import shutil
from argparse import ArgumentParser
import grover.util.utils as fea_utils
import numpy as np
# Command-line interface of the splitting script. Every option has a default,
# so the script can be launched without any arguments.
parser = ArgumentParser()

# CSV file holding the molecules; its first row is read as the header.
parser.add_argument(
    "--data_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo.csv",
)

# Feature file handed to ``fea_utils.load_features``; it must provide one
# feature row per CSV data row.
parser.add_argument(
    "--features_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz",
)

# Maximum number of samples written into each output shard.
parser.add_argument("--sample_per_file", default=1000, type=int)

# Directory that receives the shards; it is wiped and recreated by ``run``.
parser.add_argument(
    "--output_path",
    default="../drug_data/grover_data/delaneyfreesolvlipo",
)
def load_smiles(data_path):
    """Read a CSV file and return its data rows together with its header.

    Args:
        data_path: Path of the CSV file. The first row is taken as the header.

    Returns:
        A ``(rows, header)`` tuple. ``rows`` is a list containing one list of
        string fields per data row (header excluded); ``header`` is the list
        of fields of the first row.

    Raises:
        StopIteration: If the file has no rows at all, i.e. not even a header.
    """
    # newline="" is what the csv module expects: it disables universal-newline
    # translation so line breaks embedded in quoted fields survive unchanged.
    with open(data_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        res = list(reader)
    return res, header
def load_features(data_path):
    """Load the features stored at ``data_path``.

    This is a thin wrapper that forwards the path to
    ``grover.util.utils.load_features`` and returns its result unchanged.

    Args:
        data_path: Path of the feature file to load.

    Returns:
        Whatever ``fea_utils.load_features`` returns for that path.
    """
    return fea_utils.load_features(data_path)
def save_smiles(data_path, index, data, header):
    """Write one shard of rows to ``<data_path>/<index>.csv``.

    Args:
        data_path: Existing directory in which the CSV file is created.
        index: Shard identifier; ``str(index)`` becomes the file stem.
        data: Iterable of rows, each an iterable of fields.
        header: Header row written before the data rows.
    """
    fn = os.path.join(data_path, str(index) + ".csv")
    # newline="" stops the text layer from translating the "\r\n" emitted by
    # csv.writer, which would otherwise become "\r\r\n" on some platforms.
    with open(fn, "w", newline="") as f:
        fw = csv.writer(f)
        fw.writerow(header)
        fw.writerows(data)
def save_features(data_path, index, data):
    """Store one shard of features as ``<data_path>/<index>.npz``.

    The array is written with ``np.savez_compressed`` under the single key
    ``features``.

    Args:
        data_path: Existing directory in which the archive is created.
        index: Shard identifier; ``str(index)`` becomes the file stem.
        data: Array-like object holding the features of this shard.
    """
    file_name = str(index) + ".npz"
    target = os.path.join(data_path, file_name)
    np.savez_compressed(target, features=data)
def run():
    """Shuffle the dataset and split it into fixed-size shard files.

    Reads the command-line options from ``parser``, loads the CSV rows and the
    feature matrix, applies one random permutation to both, and writes::

        <output_path>/graph/<i>.csv     rows of shard i (with header)
        <output_path>/feature/<i>.npz   features of shard i
        <output_path>/summary.txt       n_files / n_samples / sample_per_file

    An existing ``output_path`` directory is deleted before writing.

    Raises:
        SystemExit: Via ``parser.error`` if ``--sample_per_file`` is not a
            positive integer.
        ValueError: If the number of CSV rows differs from the number of
            feature rows.
    """
    args = parser.parse_args()
    if args.sample_per_file <= 0:
        # A zero or negative shard size would give a division by zero or a
        # nonsensical split further down.
        parser.error("--sample_per_file must be a positive integer")

    res, header = load_smiles(data_path=args.data_path)
    fea = load_features(data_path=args.features_path)

    n_graphs = len(res)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if n_graphs != fea.shape[0]:
        raise ValueError(
            "Number of samples (%d) does not match number of feature rows (%d)"
            % (n_graphs, fea.shape[0])
        )

    perm = np.random.permutation(n_graphs)
    # NOTE(review): when n_graphs is an exact multiple of sample_per_file this
    # yields one trailing shard with no samples. The formula is kept unchanged
    # because readers of summary.txt may depend on this layout - confirm
    # before altering it.
    nfold = int(n_graphs / args.sample_per_file + 1)
    print("Number of files: %d" % nfold)

    # Start from a clean output directory; inputs were loaded and validated
    # above, so nothing is deleted when they are unusable.
    if os.path.exists(args.output_path):
        shutil.rmtree(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)
    graph_path = os.path.join(args.output_path, "graph")
    fea_path = os.path.join(args.output_path, "feature")
    os.makedirs(graph_path, exist_ok=True)
    os.makedirs(fea_path, exist_ok=True)

    for i in range(nfold):
        sidx = i * args.sample_per_file
        eidx = min((i + 1) * args.sample_per_file, n_graphs)
        indexes = perm[sidx:eidx]
        # The same permuted indexes select rows and features, keeping the two
        # shard files aligned.
        sres = [res[j] for j in indexes]
        sfea = fea[indexes]
        save_smiles(graph_path, i, sres, header)
        save_features(fea_path, i, sfea)

    summary_path = os.path.join(args.output_path, "summary.txt")
    # ``with`` guarantees the summary file is closed even if a write fails.
    with open(summary_path, "w") as summary_fout:
        summary_fout.write("n_files:%d\n" % nfold)
        summary_fout.write("n_samples:%d\n" % n_graphs)
        summary_fout.write("sample_per_file:%d\n" % args.sample_per_file)
if __name__ == "__main__":
    # Perform the split only when executed as a script, not when imported.
    run()