""" The data splitting script for pretraining. """ import csv import os import shutil from argparse import ArgumentParser import grover.util.utils as fea_utils import numpy as np parser = ArgumentParser() parser.add_argument( "--data_path", default="../drug_data/grover_data/delaneyfreesolvlipo.csv" ) parser.add_argument( "--features_path", default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz", ) parser.add_argument("--sample_per_file", type=int, default=1000) parser.add_argument( "--output_path", default="../drug_data/grover_data/delaneyfreesolvlipo" ) def load_smiles(data_path): with open(data_path) as f: reader = csv.reader(f) header = next(reader) res = [] for line in reader: res.append(line) return res, header def load_features(data_path): fea = fea_utils.load_features(data_path) return fea def save_smiles(data_path, index, data, header): fn = os.path.join(data_path, str(index) + ".csv") with open(fn, "w") as f: fw = csv.writer(f) fw.writerow(header) for d in data: fw.writerow(d) def save_features(data_path, index, data): fn = os.path.join(data_path, str(index) + ".npz") np.savez_compressed(fn, features=data) def run(): args = parser.parse_args() res, header = load_smiles(data_path=args.data_path) fea = load_features(data_path=args.features_path) assert len(res) == fea.shape[0] n_graphs = len(res) perm = np.random.permutation(n_graphs) nfold = int(n_graphs / args.sample_per_file + 1) print("Number of files: %d" % nfold) if os.path.exists(args.output_path): shutil.rmtree(args.output_path) os.makedirs(args.output_path, exist_ok=True) graph_path = os.path.join(args.output_path, "graph") fea_path = os.path.join(args.output_path, "feature") os.makedirs(graph_path, exist_ok=True) os.makedirs(fea_path, exist_ok=True) for i in range(nfold): sidx = i * args.sample_per_file eidx = min((i + 1) * args.sample_per_file, n_graphs) indexes = perm[sidx:eidx] sres = [res[j] for j in indexes] sfea = fea[indexes] save_smiles(graph_path, i, sres, header) save_features(fea_path, i, sfea) summary_path = os.path.join(args.output_path, "summary.txt") summary_fout = open(summary_path, "w") summary_fout.write("n_files:%d\n" % nfold) summary_fout.write("n_samples:%d\n" % n_graphs) summary_fout.write("sample_per_file:%d\n" % args.sample_per_file) summary_fout.close() if __name__ == "__main__": run()