# ReactionT5 / task_yield / get_distance.py
# Uploaded by sagawa ("Upload 42 files", commit 08ccc8e, verified).
import argparse
import os
import sys
import warnings
import numpy as np
import pandas as pd
import torch
# Put the directory one level above this script on the import path so the
# shared `utils` module imported below can be resolved when the script is run
# directly (presumably utils.py lives in that parent directory — confirm).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import seed_everything
# Suppress every Python warning for this script. Note: this also hides
# deprecation and numerical warnings that could be useful when debugging.
warnings.filterwarnings("ignore")
def parse_args():
    """Parse the command-line options for the similar-reaction search.

    Returns:
        argparse.Namespace with attributes:
            input_data (str): path to the input CSV (required).
            target_embedding (str): path to the target embedding ``.npy`` file (required).
            query_embedding (str): path to the query embedding ``.npy`` file (required).
            top_k (int): number of similar reactions to retrieve (default 1).
            batch_size (int): number of query rows processed per batch (default 64).
            output_dir (str): directory for the output file (default "./").

    Exits via ``SystemExit`` (argparse behavior) on ``--help`` or invalid/missing
    arguments.
    """
    parser = argparse.ArgumentParser(description="Search for similar reactions.")
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--target_embedding",
        type=str,
        required=True,
        help="Path to the target embedding.",
    )
    parser.add_argument(
        "--query_embedding",
        type=str,
        required=True,
        help="Path to the query embedding.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1,
        help="Number of similar reactions to retrieve.",
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where results are saved.",
    )
    return parser.parse_args()
def max_cosine_similarity(query_embedding, target_embedding, batch_size=64, device=None):
    """Return, for every query row, its highest cosine similarity to any target row.

    Both inputs are 2-D array-likes whose rows are embedding vectors with the same
    width. Rows are L2-normalized, so the matrix product of a query batch with the
    transposed target matrix gives cosine similarities; the maximum over targets
    is kept for each query row.

    Args:
        query_embedding: 2-D array-like of query vectors.
        target_embedding: 2-D array-like of target vectors.
        batch_size: number of query rows scored per matrix product; must be >= 1.
        device: torch device to compute on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        1-D float64 numpy array with one value per query row (empty if there are
        no query rows).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        # Fall back to CPU instead of failing on machines without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    target = torch.as_tensor(np.asarray(target_embedding), dtype=torch.float32, device=device)
    query = torch.as_tensor(np.asarray(query_embedding), dtype=torch.float32, device=device)
    target = torch.nn.functional.normalize(target, p=2, dim=1)
    query = torch.nn.functional.normalize(query, p=2, dim=1)

    chunks = []
    with torch.no_grad():  # pure inference; no autograd graph needed
        for start in range(0, query.shape[0], batch_size):
            print(f"Processing batch {start // batch_size}...")
            batch = query[start : start + batch_size]
            similarity = torch.matmul(batch, target.T)
            best, _ = torch.max(similarity, dim=1)
            # Widen to float64 so the values written to CSV match the previous
            # behavior (Python floats from .tolist() became a float64 array).
            chunks.append(best.cpu().numpy().astype(np.float64))

    if not chunks:
        return np.empty(0, dtype=np.float64)
    return np.concatenate(chunks)


if __name__ == "__main__":
    config = parse_args()
    seed_everything(42)

    # NOTE(review): --top_k is parsed but not used; only the single best match
    # per query contributes to the output.
    # The "distance" column holds the maximum cosine similarity (higher means
    # more similar), not a distance in the metric sense.
    distances = max_cosine_similarity(
        np.load(config.query_embedding),
        np.load(config.target_embedding),
        batch_size=config.batch_size,
    )

    df = pd.read_csv(config.input_data)
    # Rows of input_data are assumed to align one-to-one (same order) with the
    # query embedding rows — confirm against the embedding-generation step.
    if len(df) != len(distances):
        raise ValueError(
            f"input_data has {len(df)} rows but query_embedding has "
            f"{len(distances)} rows; they must match."
        )
    df["distance"] = distances
    os.makedirs(config.output_dir, exist_ok=True)
    df.to_csv(os.path.join(config.output_dir, "distance.csv"), index=False)