import argparse
import os
import sys
import warnings

import pandas as pd
import rdkit
from rdkit import Chem
from transformers import AutoTokenizer

# Silence RDKit logging so invalid SMILES in the predictions do not flood stdout.
rdkit.RDLogger.DisableLog("rdApp.*")

# Make the repository root importable so the shared utils module can be found.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import canonicalize, seed_everything

warnings.filterwarnings("ignore")
def parse_args():
    parser = argparse.ArgumentParser(
        description="Script for evaluating reaction retrosynthesis predictions."
    )
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data (CSV of beam-search predictions).",
    )
    parser.add_argument(
        "--target_data",
        type=str,
        required=True,
        help="Path to the target data (CSV containing the ground truth).",
    )
    parser.add_argument(
        "--target_col",
        type=str,
        required=True,
        help="Name of the target column.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-retrosynthesis",
        help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.",
    )
    parser.add_argument(
        "--num_beams", type=int, default=5, help="Number of beams used for beam search."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for reproducibility."
    )
    return parser.parse_args()
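

# Example invocation (a minimal sketch; the script name, file paths, and column
# name below are placeholders for illustration, not files shipped with the repo):
#
#   python calculate_accuracy.py \
#       --input_data retrosynthesis_predictions.csv \
#       --target_data test.csv \
#       --target_col REACTANT
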
def remove_space(row):
    """Strip whitespace from a row's beam predictions (assumes the five columns 0th-4th)."""
    for i in range(5):
        row[f"{i}th"] = row[f"{i}th"].replace(" ", "")
    return row

if __name__ == "__main__":
    CFG = parse_args()
    seed_everything(seed=CFG.seed)

    # The tokenizer is loaded here but is not used directly in the evaluation below.
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )

    # Load the beam-search predictions, fill missing beams with a blank string,
    # attach the ground-truth column, and strip spaces from every prediction.
    df = pd.read_csv(CFG.input_data)
    df[[f"{i}th" for i in range(CFG.num_beams)]] = df[
        [f"{i}th" for i in range(CFG.num_beams)]
    ].fillna(" ")
    df["target"] = pd.read_csv(CFG.target_data)[CFG.target_col].values
    df = df.apply(remove_space, axis=1)

    # For each row, compare the canonicalized beam predictions against the
    # canonicalized target to record top-1/2/3/5 hits, and count how many of
    # the top-k predictions fail to parse as valid SMILES.
    top_k_invalidity = CFG.num_beams
    top1, top2, top3, top5 = [], [], [], []
    invalidity = []
    for idx, row in df.iterrows():
        target = canonicalize(row["target"])
        if canonicalize(row["0th"]) == target:
            top1.append(1)
            top2.append(1)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["1th"]) == target:
            top1.append(0)
            top2.append(1)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["2th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["3th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(1)
        elif canonicalize(row["4th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(1)
        else:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(0)

        input_compound = row["input"]
        output = [row[f"{i}th"] for i in range(top_k_invalidity)]
        inval_score = 0
        for ith, out in enumerate(output):
            # Strip any trailing "." so a dangling fragment separator does not
            # make an otherwise valid SMILES unparsable.
            mol = Chem.MolFromSmiles(out.rstrip("."))
            if not isinstance(mol, Chem.rdchem.Mol):
                inval_score += 1
        invalidity.append(inval_score)

    print(CFG.input_data)
    print(f"Top 1 accuracy: {sum(top1) / len(top1)}")
    print(f"Top 2 accuracy: {sum(top2) / len(top2)}")
    print(f"Top 3 accuracy: {sum(top3) / len(top3)}")
    print(f"Top 5 accuracy: {sum(top5) / len(top5)}")
    print(
        f"Top {top_k_invalidity} Invalidity: {sum(invalidity) / (len(invalidity) * top_k_invalidity) * 100}"
    )