# ReactionT5 / task_forward / calculate_accuracy.py
# Uploaded by sagawa ("Upload 42 files"), commit 08ccc8e.
import argparse
import os
import sys
import warnings
import pandas as pd
import rdkit
from rdkit import Chem
from transformers import AutoTokenizer
rdkit.RDLogger.DisableLog("rdApp.*")
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import canonicalize, seed_everything
warnings.filterwarnings("ignore")
def parse_args(argv=None):
    """Parse command-line options for the accuracy-calculation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``input_data``, ``target_data``,
        ``target_col``, ``model_name_or_path``, ``num_beams`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Calculate top-k accuracy and invalidity of reaction prediction results."
    )
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--target_data",
        type=str,
        required=True,
        help="Path to the target data.",
    )
    parser.add_argument(
        "--target_col",
        type=str,
        required=True,
        help="Name of target column.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-retrosynthesis",
        help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.",
    )
    parser.add_argument(
        "--num_beams", type=int, default=5, help="Number of beams used for beam search."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for reproducibility."
    )
    return parser.parse_args(argv)
def remove_space(row, num_beams=5):
    """Remove every space character from the beam-prediction fields of a row.

    Args:
        row: A mapping-like row (e.g. the Series passed by
            ``DataFrame.apply(..., axis=1)``) holding string predictions under
            the keys ``"0th"``, ``"1th"``, ...
        num_beams: Number of leading beam fields to clean. Defaults to 5,
            matching the previously hard-coded count; pass the real beam
            count so fewer beams do not raise KeyError and extra beams are
            not left uncleaned.

    Returns:
        The same ``row`` object, modified in place.
    """
    for i in range(num_beams):
        key = f"{i}th"
        row[key] = row[key].replace(" ", "")
    return row
if __name__ == "__main__":
    # Compare beam-search predictions with ground truth and report:
    #   * top-k accuracy: fraction of rows whose canonicalized target equals a
    #     canonicalized prediction among the first k beams;
    #   * invalidity: percentage of predicted strings that RDKit cannot parse.
    CFG = parse_args()
    seed_everything(seed=CFG.seed)

    if CFG.num_beams < 1:
        sys.exit(f"--num_beams must be >= 1, got {CFG.num_beams}")

    # --model_name_or_path is still accepted for CLI compatibility, but no
    # tokenizer/model is loaded: the metrics need only the two CSV files.

    beam_cols = [f"{i}th" for i in range(CFG.num_beams)]
    df = pd.read_csv(CFG.input_data)
    if len(df) == 0:
        sys.exit(f"No predictions found in {CFG.input_data}")

    # Missing beams become empty strings, and spaces are stripped so that
    # space-separated output (e.g. "C C O") compares as plain SMILES ("CCO").
    df[beam_cols] = (
        df[beam_cols]
        .fillna("")
        .astype(str)
        .apply(lambda col: col.str.replace(" ", "", regex=False))
    )
    df["target"] = pd.read_csv(CFG.target_data)[CFG.target_col].values

    n_rows = len(df)
    # Report only k values that the available beams can actually support.
    ks = tuple(k for k in (1, 2, 3, 5) if k <= CFG.num_beams)
    hits = dict.fromkeys(ks, 0)
    n_invalid = 0

    for preds, raw_target in zip(
        df[beam_cols].itertuples(index=False, name=None), df["target"]
    ):
        target = canonicalize(raw_target)
        # Rank of the first matching beam; beams beyond max(ks) can never
        # contribute to any reported top-k, so they are not canonicalized.
        rank = next(
            (
                i
                for i, pred in enumerate(preds[: max(ks)])
                if canonicalize(pred) == target
            ),
            None,
        )
        if rank is not None:
            for k in ks:
                if rank < k:
                    hits[k] += 1

        # Every beam is checked for validity; a trailing "." is stripped
        # before parsing, and anything not parsed into a Mol counts as invalid.
        n_invalid += sum(
            not isinstance(Chem.MolFromSmiles(pred.rstrip(".")), Chem.rdchem.Mol)
            for pred in preds
        )

    print(CFG.input_data)
    for k in ks:
        print(f"Top {k} accuracy: {hits[k] / n_rows}")
    print(
        f"Top {CFG.num_beams} Invalidity: {n_invalid / (n_rows * CFG.num_beams) * 100}"
    )